
<!DOCTYPE html>

<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>tigramite.causal_effects &#8212; Tigramite 5.2 documentation</title>
    <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
    <link rel="stylesheet" type="text/css" href="../../_static/alabaster.css" />
    <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
    <script src="../../_static/jquery.js"></script>
    <script src="../../_static/underscore.js"></script>
    <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    <script src="../../_static/doctools.js"></script>
    <link rel="index" title="Index" href="../../genindex.html" />
    <link rel="search" title="Search" href="../../search.html" />
   
  <link rel="stylesheet" href="../../_static/custom.css" type="text/css" />
  
  
  <meta name="viewport" content="width=device-width, initial-scale=0.9, maximum-scale=0.9" />

  </head><body>
  

    <div class="document">
      <div class="documentwrapper">
        <div class="bodywrapper">
          

          <div class="body" role="main">
            
  <h1>Source code for tigramite.causal_effects</h1><div class="highlight"><pre>
<span></span><span class="sd">&quot;&quot;&quot;Tigramite causal inference for time series.&quot;&quot;&quot;</span>

<span class="c1"># Author: Jakob Runge &lt;jakob@jakob-runge.com&gt;</span>
<span class="c1">#</span>
<span class="c1"># License: GNU General Public License v3.0</span>

<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">itertools</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
<span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span>
<span class="kn">from</span> <span class="nn">tigramite.models</span> <span class="kn">import</span> <span class="n">Models</span>
<span class="kn">import</span> <span class="nn">struct</span>
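
# --- Illustrative sketch (hypothetical helpers, not part of tigramite) ---
# CausalEffects.__init__ below derives several node sets from the graph:
# anc(X), anc(Y), anc(S), the mediators M, the forbidden nodes
# X | des(Y) | des(M), and the valid ancestors, i.e.
# (anc(X) | anc(Y) | anc(S)) minus the forbidden nodes. The helpers below
# sketch that bookkeeping on a plain DAG given as a dict mapping each node
# to the set of its children. Assumptions: nodes are plain labels (not
# (i, -tau) tuples), anc(W) and des(W) include W itself, and M is
# approximated as (des(X) - X) intersected with (anc(Y) - Y); tigramite's
# get_mediators() instead restricts M to nodes on proper causal paths.


def _sketch_closure(nodes, step):
    """Return `nodes` plus every node reachable from them via `step`."""
    seen = set(nodes)
    stack = list(nodes)
    while stack:
        node = stack.pop()
        for nxt in step.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen


def _sketch_node_sets(children, X, Y, S=()):
    """Compute the node sets used for adjustment on a simple DAG."""
    # Invert the child map to obtain each node's parents.
    parents = {}
    for node, kids in children.items():
        for kid in kids:
            parents.setdefault(kid, set()).add(node)

    X, Y, S = set(X), set(Y), set(S)
    ancX = _sketch_closure(X, parents)
    ancY = _sketch_closure(Y, parents)
    ancS = _sketch_closure(S, parents)
    desX = _sketch_closure(X, children)
    desY = _sketch_closure(Y, children)

    # Simplified mediators: strict descendants of X that are
    # strict ancestors of Y.
    M = (desX - X).intersection(ancY - Y)
    desM = _sketch_closure(M, children)

    # Conditioning on descendants of Y makes the effect non-identifiable.
    if S.intersection(desY):
        raise ValueError("Not identifiable: Conditions S overlap with des(Y).")

    forbidden = desY.union(desM).union(X)
    return {
        "no_causal_path": not ancY.intersection(X),
        "mediators": M,
        "forbidden": forbidden,
        "valid_ancestors": ancX.union(ancY).union(ancS) - forbidden,
    }


# Usage: with edges z to x, z to y, x to m, and m to y, the call
# _sketch_node_sets({"z": {"x", "y"}, "x": {"m"}, "m": {"y"}}, ["x"], ["y"])
# yields mediators {"m"}, forbidden nodes {"x", "m", "y"}, and the single
# valid ancestor {"z"}, the classic backdoor adjustment variable.
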

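# --- Illustrative sketch (hypothetical helper, not part of tigramite) ---
# _construct_graph() below accepts graphs of different dimensionality
# depending on graph_type: 2 for 'dag'/'admg', 4 for 'tsg_dag'/'tsg_admg',
# and 3 for 'stationary_dag'/'stationary_admg'. An [N, N] graph is lifted
# to the time-series-graph shape [N, N, 1, 1] via np.expand_dims, so that
# tau_max = shape[2] - 1 = 0. This sketch reproduces only that shape logic
# and ignores hidden variables and the latent projection.
import numpy as np

_SKETCH_EXPECTED_NDIM = {
    "dag": 2, "admg": 2,
    "tsg_dag": 4, "tsg_admg": 4,
    "stationary_dag": 3, "stationary_admg": 3,
}


def _sketch_to_tsg_shape(graph, graph_type):
    """Return (graph, tau_max), lifting [N, N] graphs to [N, N, 1, 1]."""
    expected = _SKETCH_EXPECTED_NDIM.get(graph_type)
    if expected is None:
        raise ValueError("Unsupported graph_type %s" % graph_type)
    if graph.ndim != expected:
        raise ValueError("graph_type %s assumes graph.ndim == %d"
                         % (graph_type, expected))
    if expected == 2:
        # Add two dummy lag dimensions so the graph can be handled like a tsg.
        graph = np.expand_dims(graph, axis=(2, 3))
    # In all three layouts the lag axis is axis 2, of length tau_max + 1.
    return graph, graph.shape[2] - 1

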
<div class="viewcode-block" id="CausalEffects"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects">[docs]</a><span class="k">class</span> <span class="nc">CausalEffects</span><span class="p">():</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Causal effect estimation.</span>

<span class="sd">    Methods for the estimation of linear or non-parametric causal effects </span>
<span class="sd">    between (potentially multivariate) X and Y (potentially conditional </span>
<span class="sd">    on S) by (generalized) backdoor adjustment. Various graph types are </span>
<span class="sd">    supported, including graphs with hidden variables.</span>
<span class="sd">    </span>
<span class="sd">    Linear and non-parametric estimators are based on sklearn. For the </span>
<span class="sd">    linear case without hidden variables, an efficient estimator based on </span>
<span class="sd">    Wright&#39;s path coefficients is also available. This estimator also </span>
<span class="sd">    allows the estimation of mediation effects.</span>

<span class="sd">    See the corresponding paper [6]_ and the tigramite tutorial for an </span>
<span class="sd">    in-depth introduction.</span>

<span class="sd">    References</span>
<span class="sd">    ----------</span>

<span class="sd">    .. [6] J. Runge, Necessary and sufficient graphical conditions for</span>
<span class="sd">           optimal adjustment sets in causal graphical models with </span>
<span class="sd">           hidden variables, Advances in Neural Information Processing</span>
<span class="sd">           Systems, 2021, 34 </span>
<span class="sd">           https://proceedings.neurips.cc/paper/2021/hash/8485ae387a981d783f8764e508151cd9-Abstract.html</span>


<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
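<span class="sd">    graph : array of shape [N, N], [N, N, tau_max+1], or [N, N, tau_max+1, tau_max+1]</span>
<span class="sd">        Causal graph. The expected shape depends on graph_type: [N, N] for</span>
<span class="sd">        &#39;dag&#39; and &#39;admg&#39;, [N, N, tau_max+1] for &#39;stationary_dag&#39; and</span>
<span class="sd">        &#39;stationary_admg&#39;, and [N, N, tau_max+1, tau_max+1] for</span>
<span class="sd">        &#39;tsg_dag&#39; and &#39;tsg_admg&#39;. See the tutorial for details.</span>
<span class="sd">    graph_type : str</span>
<span class="sd">        Type of graph. One of &#39;dag&#39;, &#39;admg&#39;, &#39;tsg_dag&#39;, &#39;tsg_admg&#39;,</span>
<span class="sd">        &#39;stationary_dag&#39;, or &#39;stationary_admg&#39;.</span>
<span class="sd">    X : list of tuples</span>
<span class="sd">        List of tuples [(i, -tau), ...] containing cause variables.</span>
<span class="sd">    Y : list of tuples</span>
<span class="sd">        List of tuples [(j, 0), ...] containing effect variables.</span>
<span class="sd">    S : list of tuples, optional (default: None)</span>
<span class="sd">        List of tuples [(i, -tau), ...] containing conditioning variables.</span>
<span class="sd">    hidden_variables : list of tuples, optional (default: None)</span>
<span class="sd">        Hidden variables in format [(i, -tau), ...]. The internal graph is</span>
<span class="sd">        constructed by a latent projection.</span>
<span class="sd">    check_SM_overlap : bool, optional (default: True)</span>
<span class="sd">        Whether to check if the conditions S overlap with the mediators M;</span>
<span class="sd">        if they do, a ValueError is raised.</span>
<span class="sd">    verbosity : int, optional (default: 0)</span>
<span class="sd">        Level of verbosity.</span>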
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">graph</span><span class="p">,</span>
                 <span class="n">graph_type</span><span class="p">,</span>
                 <span class="n">X</span><span class="p">,</span>
                 <span class="n">Y</span><span class="p">,</span>
                 <span class="n">S</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">hidden_variables</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">check_SM_overlap</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                 <span class="n">verbosity</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
        
        <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">=</span> <span class="n">verbosity</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">N</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

        <span class="k">if</span> <span class="n">S</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">S</span> <span class="o">=</span> <span class="p">[]</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">listX</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">listY</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">Y</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">listS</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">X</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">Y</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">Y</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">S</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>    

        <span class="c1"># </span>
        <span class="c1"># Checks regarding graph type</span>
        <span class="c1">#</span>
        <span class="n">supported_graphs</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;dag&#39;</span><span class="p">,</span> 
                            <span class="s1">&#39;admg&#39;</span><span class="p">,</span>
                            <span class="s1">&#39;tsg_dag&#39;</span><span class="p">,</span>
                            <span class="s1">&#39;tsg_admg&#39;</span><span class="p">,</span>
                            <span class="s1">&#39;stationary_dag&#39;</span><span class="p">,</span>
                            <span class="s1">&#39;stationary_admg&#39;</span><span class="p">,</span>

                            <span class="c1"># &#39;mag&#39;,</span>
                            <span class="c1"># &#39;tsg_mag&#39;,</span>
                            <span class="c1"># &#39;stationary_mag&#39;,</span>
                            <span class="c1"># &#39;pag&#39;,</span>
                            <span class="c1"># &#39;tsg_pag&#39;,</span>
                            <span class="c1"># &#39;stationary_pag&#39;,</span>
                            <span class="p">]</span>
        <span class="k">if</span> <span class="n">graph_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_graphs</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Only graph types </span><span class="si">%s</span><span class="s2"> supported!&quot;</span> <span class="o">%</span><span class="n">supported_graphs</span><span class="p">)</span>

        <span class="c1"># TODO?: check that masking aligns with hidden samples in variables</span>
        <span class="k">if</span> <span class="n">hidden_variables</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">hidden_variables</span> <span class="o">=</span> <span class="p">[]</span>
        
        <span class="bp">self</span><span class="o">.</span><span class="n">hidden_variables</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">hidden_variables</span><span class="p">)</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_variables</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">)))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;XYS overlaps with hidden_variables!&quot;</span><span class="p">)</span>

        <span class="c1"># Only needed for later extension to MAG/PAGs</span>
        <span class="k">if</span> <span class="s1">&#39;pag&#39;</span> <span class="ow">in</span> <span class="n">graph_type</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">possible</span> <span class="o">=</span> <span class="kc">True</span> 
            <span class="bp">self</span><span class="o">.</span><span class="n">definite_status</span> <span class="o">=</span> <span class="kc">True</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">possible</span> <span class="o">=</span> <span class="kc">False</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">definite_status</span> <span class="o">=</span> <span class="kc">False</span>

        <span class="c1"># Not needed for now...</span>
        <span class="c1"># self.ignore_time_bounds = False</span>

        <span class="c1"># Construct internal graph from input graph depending on graph type</span>
        <span class="c1"># and hidden variables</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_construct_graph</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="n">graph</span><span class="p">,</span> <span class="n">graph_type</span><span class="o">=</span><span class="n">graph_type</span><span class="p">,</span>
                              <span class="n">hidden_variables</span><span class="o">=</span><span class="n">hidden_variables</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">_check_graph</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">graph</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">_check_XYS</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">ancX</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_ancestors</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">ancY</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_ancestors</span><span class="p">(</span><span class="n">Y</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">ancS</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_ancestors</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>

        <span class="c1"># If no node in X is an ancestor of Y, no causal path from X to Y exists</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ancY</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">X</span><span class="p">))</span> <span class="o">==</span> <span class="nb">set</span><span class="p">():</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span> <span class="o">=</span> <span class="kc">True</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No causal path from X to Y exists.&quot;</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span> <span class="o">=</span> <span class="kc">False</span>

        <span class="c1"># Get mediators</span>
        <span class="n">mediators</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_mediators</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span> 

        <span class="n">M</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">mediators</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">M</span> <span class="o">=</span> <span class="n">M</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">listM</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">):</span>
            <span class="k">if</span> <span class="nb">abs</span><span class="p">(</span><span class="n">varlag</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;X, Y, S must have time lags inside graph.&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Overlap between X and Y.&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">)))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Conditions S overlap with X or Y.&quot;</span><span class="p">)</span>

        <span class="c1"># # TODO: need to prove that this is sufficient for non-identifiability!</span>
        <span class="c1"># if len(self.X.intersection(self._get_descendants(self.M))) &gt; 0:</span>
        <span class="c1">#     raise ValueError(&quot;Not identifiable: Overlap between X and des(M)&quot;)</span>

        <span class="k">if</span> <span class="n">check_SM_overlap</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Conditions S overlap with mediators M.&quot;</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">desX</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_descendants</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">desY</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_descendants</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">desM</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_descendants</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">descendants</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">desY</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">desM</span><span class="p">)</span>

        <span class="c1"># Define the forbidden nodes as X together with the descendants of Y and M</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">descendants</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">)</span>  <span class="c1">#.union(S)</span>

        <span class="c1"># Define valid ancestors</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">vancs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ancX</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ancY</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ancS</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">desX</span><span class="p">))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Warning: Potentially outside assumptions: Conditions S overlap with des(X)&quot;</span><span class="p">)</span>

        <span class="c1"># Here only check if S overlaps with des(Y), leave the option that S</span>
        <span class="c1"># contains variables in des(M) to the user</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">desY</span><span class="p">))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Not identifiable: Conditions S overlap with des(Y).&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">##</span><span class="se">\n</span><span class="s2">## Initializing CausalEffects class</span><span class="se">\n</span><span class="s2">##&quot;</span>
                  <span class="s2">&quot;</span><span class="se">\n\n</span><span class="s2">Input:&quot;</span><span class="p">)</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">graph_type = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">graph_type</span>
                  <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">X = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">listX</span>
                  <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Y = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span>
                  <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">S = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">listS</span>
                  <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">M = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">listM</span>
                  <span class="p">)</span>
            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_variables</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">hidden_variables = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_variables</span>
                      <span class="p">)</span> 
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n\n</span><span class="s2">&quot;</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No causal path from X to Y exists!&quot;</span><span class="p">)</span>


    <span class="k">def</span> <span class="nf">_construct_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">,</span> <span class="n">graph_type</span><span class="p">,</span> <span class="n">hidden_variables</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Construct internal graph object based on input graph and hidden variables.</span>

<span class="sd">           Uses the latent projection operation.</span>
<span class="sd">        &quot;&quot;&quot;</span>


        <span class="k">if</span> <span class="n">graph_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;dag&#39;</span><span class="p">,</span> <span class="s1">&#39;admg&#39;</span><span class="p">]:</span> 
            <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">ndim</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;graph_type in [&#39;dag&#39;, &#39;admg&#39;] assumes graph.shape=(N, N).&quot;</span><span class="p">)</span>
            <span class="c1"># Convert to shape [N, N, 1, 1] with dummy dimension</span>
            <span class="c1"># to process as tsg_dag or tsg_admg with potential hidden variables</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">graph</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
            
            <span class="c1"># tau_max needed in _get_latent_projection_graph</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">=</span> <span class="mi">0</span>

            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">hidden_variables</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_latent_projection_graph</span><span class="p">()</span> <span class="c1"># stationary=False)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span> <span class="o">=</span> <span class="s2">&quot;tsg_admg&quot;</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="c1"># graph = self.graph</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span> <span class="o">=</span> <span class="s1">&#39;tsg_&#39;</span> <span class="o">+</span> <span class="n">graph_type</span>

        <span class="k">elif</span> <span class="n">graph_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;tsg_dag&#39;</span><span class="p">,</span> <span class="s1">&#39;tsg_admg&#39;</span><span class="p">]:</span>
            <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">ndim</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;tsg-graph_type assumes graph.shape=(N, N, tau_max+1, tau_max+1).&quot;</span><span class="p">)</span>

            <span class="c1"># tau_max is implicitly derived from the graph dimensions</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">graph</span> <span class="o">=</span> <span class="n">graph</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>

            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">hidden_variables</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_latent_projection_graph</span><span class="p">()</span> <span class="c1">#, stationary=False)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span> <span class="o">=</span> <span class="s2">&quot;tsg_admg&quot;</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span> <span class="o">=</span> <span class="n">graph_type</span>   

        <span class="k">elif</span> <span class="n">graph_type</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;stationary_dag&#39;</span><span class="p">,</span> <span class="s1">&#39;stationary_admg&#39;</span><span class="p">]:</span>
            <span class="c1"># Currently only stationary_dag without hidden variables is supported</span>
            <span class="k">if</span> <span class="n">graph</span><span class="o">.</span><span class="n">ndim</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;stationary graph_type assumes graph.shape=(N, N, tau_max+1).&quot;</span><span class="p">)</span>
            
            <span class="c1"># # TODO: remove if theory for stationary ADMGs is clear</span>
            <span class="c1"># if graph_type == &#39;stationary_dag&#39; and len(hidden_variables) &gt; 0:</span>
            <span class="c1">#     raise ValueError(&quot;Hidden variables currently not supported for &quot;</span>
            <span class="c1">#                      &quot;stationary_dag.&quot;)</span>

            <span class="c1"># For a stationary DAG without hidden variables it&#39;s sufficient to consider</span>
            <span class="c1"># a tau_max that includes the parents of X, Y, M, and S. A conservative</span>
            <span class="c1"># estimate thereof is simply the lag-dimension of the stationary DAG plus</span>
            <span class="c1"># the maximum lag of XYS.</span>
            <span class="n">statgraph_tau_max</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
            <span class="n">maxlag_XYS</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">):</span>
                <span class="n">maxlag_XYS</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">maxlag_XYS</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">varlag</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>

            <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">=</span> <span class="n">maxlag_XYS</span> <span class="o">+</span> <span class="n">statgraph_tau_max</span>
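            <span class="c1"># Illustrative example (assumed values, not from the source): for a</span>
            <span class="c1"># stationary graph of shape (N, N, 3), statgraph_tau_max = 2; with</span>
            <span class="c1"># X = {(0, -1)}, Y = {(1, 0)} and S empty, maxlag_XYS = 1, so</span>
            <span class="c1"># tau_max = 1 + 2 = 3 and the time series graph spans lags 0..3.</span>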

            <span class="n">stat_graph</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span>

            <span class="c1">#########################################</span>
            <span class="c1"># Use this tau_max and construct ADMG by assuming paths of</span>
            <span class="c1"># maximal lag 10*tau_max... TO BE REVISED!</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">graph</span> <span class="o">=</span> <span class="n">graph</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_latent_projection_graph</span><span class="p">(</span><span class="n">stationary</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span> <span class="o">=</span> <span class="s2">&quot;tsg_admg&quot;</span>
            <span class="c1">#########################################</span>

            <span class="c1"># Also create stationary graph extended to tau_max</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">stationary_graph</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;&lt;U3&#39;</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">stationary_graph</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="n">stat_graph</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">stat_graph</span>
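            <span class="c1"># Illustrative example (assumed values): for a stat_graph of shape</span>
            <span class="c1"># (N, N, 3) and tau_max = 3, stationary_graph has shape (N, N, 4); slices</span>
            <span class="c1"># 0..2 copy stat_graph and slice 3 keeps the empty strings that np.zeros</span>
            <span class="c1"># produces for dtype &#39;&lt;U3&#39;, i.e. no links at lag 3.</span>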

            <span class="c1"># allowed_edges = [&quot;--&gt;&quot;, &quot;&lt;--&quot;]</span>

            <span class="c1"># # Construct tsg_graph</span>
            <span class="c1"># graph = np.zeros((self.N, self.N, self.tau_max + 1, self.tau_max + 1), dtype=&#39;&lt;U3&#39;)</span>
            <span class="c1"># graph[:] = &quot;&quot;</span>
            <span class="c1"># for (i, j) in itertools.product(range(self.N), range(self.N)):</span>
            <span class="c1">#     for jt, tauj in enumerate(range(0, self.tau_max + 1)):</span>
            <span class="c1">#         for it, taui in enumerate(range(tauj, self.tau_max + 1)):</span>
            <span class="c1">#             tau = abs(taui - tauj)</span>
            <span class="c1">#             if tau == 0 and j == i:</span>
            <span class="c1">#                 continue</span>
            <span class="c1">#             if tau &gt; statgraph_tau_max:</span>
            <span class="c1">#                 continue                        </span>

            <span class="c1">#             # if tau == 0:</span>
            <span class="c1">#             #     if stat_graph[i, j, tau] == &#39;--&gt;&#39;:</span>
            <span class="c1">#             #         graph[i, j, taui, tauj] = &quot;--&gt;&quot; </span>
            <span class="c1">#             #         graph[j, i, tauj, taui] = &quot;&lt;--&quot; </span>

            <span class="c1">#             #     # elif stat_graph[i, j, tau] == &#39;&lt;--&#39;:</span>
            <span class="c1">#             #     #     graph[i, j, taui, tauj] = &quot;&lt;--&quot;</span>
            <span class="c1">#             #     #     graph[j, i, tauj, taui] = &quot;--&gt;&quot; </span>
            <span class="c1">#             # else:</span>
            <span class="c1">#             if stat_graph[i, j, tau] == &#39;--&gt;&#39;:</span>
            <span class="c1">#                 graph[i, j, taui, tauj] = &quot;--&gt;&quot; </span>
            <span class="c1">#                 graph[j, i, tauj, taui] = &quot;&lt;--&quot; </span>
            <span class="c1">#             elif stat_graph[i, j, tau] == &#39;&lt;--&#39;:</span>
            <span class="c1">#                 pass</span>
            <span class="c1">#             elif stat_graph[i, j, tau] == &#39;&#39;:</span>
            <span class="c1">#                 pass</span>
            <span class="c1">#             else:</span>
            <span class="c1">#                 edge = stat_graph[i, j, tau]</span>
            <span class="c1">#                 raise ValueError(&quot;Invalid graph edge %s. &quot; %(edge) +</span>
            <span class="c1">#                      &quot;For graph_type = %s only %s are allowed.&quot; %(graph_type, str(allowed_edges)))</span>
      
            <span class="c1">#             # elif stat_graph[i, j, tau] == &#39;&lt;--&#39;:</span>
            <span class="c1">#             #     graph[i, j, taui, tauj] = &quot;&lt;--&quot;</span>
            <span class="c1">#             #     graph[j, i, tauj, taui] = &quot;--&gt;&quot; </span>

            <span class="c1"># self.graph_type = &#39;tsg_dag&#39;</span>
            <span class="c1"># self.graph = graph</span>


        <span class="c1"># return (graph, graph_type, self.tau_max, hidden_variables)</span>

            <span class="c1"># max_lag = self._get_maximum_possible_lag(XYZ=list(X.union(Y).union(S)), graph=graph)</span>

            <span class="c1"># stat_mediators = self._get_mediators_stationary_graph(start=X, end=Y, max_lag=max_lag)</span>
            <span class="c1"># self.tau_max = self._get_maximum_possible_lag(XYZ=list(X.union(Y).union(S).union(stat_mediators)), graph=graph)</span>
            <span class="c1"># self.tau_max = graph_taumax</span>
            <span class="c1"># for varlag in X.union(Y).union(S):</span>
            <span class="c1">#     self.tau_max = max(self.tau_max, abs(varlag[1]))</span>

            <span class="c1"># if verbosity &gt; 0:</span>
            <span class="c1">#     print(&quot;Setting tau_max = &quot;, self.tau_max)</span>

            <span class="c1"># if tau_max is None:</span>
            <span class="c1">#     self.tau_max = graph_taumax</span>
            <span class="c1">#     for varlag in X.union(Y).union(S):</span>
            <span class="c1">#         self.tau_max = max(self.tau_max, abs(varlag[1]))</span>

            <span class="c1">#     if verbosity &gt; 0:</span>
            <span class="c1">#         print(&quot;Setting tau_max = &quot;, self.tau_max)</span>
            <span class="c1"># else:</span>
                <span class="c1"># self.tau_max = graph_taumax</span>
                <span class="c1"># # Repeat hidden variable pattern </span>
                <span class="c1"># # if larger tau_max is given</span>
                <span class="c1"># if self.tau_max &gt; graph_taumax:</span>
                <span class="c1">#     for lag in range(graph_taumax + 1, self.tau_max + 1):</span>
                <span class="c1">#         for j in range(self.N):</span>
                <span class="c1">#             if (j, -(lag % (graph_taumax+1))) in self.hidden_variables:</span>
                <span class="c1">#                 self.hidden_variables.add((j, -lag))</span>
            <span class="c1"># print(self.hidden_variables)</span>

        <span class="c1">#     self.graph = self._get_latent_projection_graph(self.graph, stationary=True)</span>
        <span class="c1">#     self.graph_type = &quot;tsg_admg&quot;</span>
        <span class="c1"># else:</span>

    <span class="k">def</span> <span class="nf">_check_XYS</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check that all nodes in X, Y and S have valid variable indices and lags.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">XYS</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">xys</span> <span class="ow">in</span> <span class="n">XYS</span><span class="p">:</span>
            <span class="n">var</span><span class="p">,</span> <span class="n">lag</span> <span class="o">=</span> <span class="n">xys</span> 
            <span class="k">if</span> <span class="n">var</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">var</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;XYS vars must be in [0...N-1]&quot;</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">lag</span> <span class="o">&lt;</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="ow">or</span> <span class="n">lag</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;XYS lags must be in [-tau_max...0]&quot;</span><span class="p">)</span>
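        <span class="c1"># Illustrative example (assumed values): with N = 3 and tau_max = 2,</span>
        <span class="c1"># the node (2, -2) passes both checks, (3, 0) fails the variable check</span>
        <span class="c1"># since var &gt;= N, and (0, -3) fails the lag check since lag &lt; -tau_max.</span>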


<div class="viewcode-block" id="CausalEffects.check_XYS_paths"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.check_XYS_paths">[docs]</a>    <span class="k">def</span> <span class="nf">check_XYS_paths</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check whether nodes without proper causal paths can be removed from X and Y.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        X, Y : cleaned lists of X and Y with irrelevant nodes removed.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="c1"># TODO: Also check S...</span>
        <span class="n">oldX</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
        <span class="n">oldY</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>

        <span class="c1"># anc_Y = self._get_ancestors(self.Y)</span>
        <span class="c1"># anc_S = self._get_ancestors(self.S)</span>

        <span class="c1"># Remove first from X those nodes with no causal path to Y or S</span>
        <span class="n">X</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span> <span class="k">if</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ancY</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ancS</span><span class="p">)])</span>
        
        <span class="c1"># Remove from Y those nodes with no causal path from X</span>
        <span class="c1"># des_X = self._get_descendants(X)</span>

        <span class="n">Y</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span> <span class="k">if</span> <span class="n">y</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">desX</span><span class="p">])</span>

        <span class="c1"># Also require that every x in X has a proper causal path to Y or S,</span>
        <span class="c1"># that is, the first link goes out of x</span>
        <span class="c1"># and into a path node</span>
        <span class="n">mediators_S</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_mediators</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">)</span>
        <span class="n">path_nodes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">mediators_S</span><span class="p">))</span> 
        <span class="n">X</span> <span class="o">=</span> <span class="n">X</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">(</span><span class="n">path_nodes</span><span class="p">))</span>

        <span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">oldX</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span><span class="n">X</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Consider pruning X = </span><span class="si">%s</span><span class="s2"> to X = </span><span class="si">%s</span><span class="s2"> &quot;</span> <span class="o">%</span><span class="p">(</span><span class="n">oldX</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span> <span class="o">+</span>
                  <span class="s2">&quot;since only these have a causal path to Y or S&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">oldY</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span><span class="n">Y</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Consider pruning Y = </span><span class="si">%s</span><span class="s2"> to Y = </span><span class="si">%s</span><span class="s2"> &quot;</span> <span class="o">%</span><span class="p">(</span><span class="n">oldY</span><span class="p">,</span> <span class="n">Y</span><span class="p">)</span> <span class="o">+</span>
                  <span class="s2">&quot;since only these have a causal path from X&quot;</span><span class="p">)</span>

        <span class="k">return</span> <span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="nb">list</span><span class="p">(</span><span class="n">Y</span><span class="p">))</span></div>


    <span class="k">def</span> <span class="nf">_check_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check that graph contains only allowed, consistent edges and no directed cycles.</span>

<span class="sd">        Assumes graph.shape = (N, N, tau_max+1, tau_max+1)</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">allowed_edges</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;--&gt;&quot;</span><span class="p">,</span> <span class="s2">&quot;&lt;--&quot;</span><span class="p">]</span>
        <span class="k">if</span> <span class="s1">&#39;admg&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span><span class="p">:</span>
            <span class="n">allowed_edges</span> <span class="o">+=</span> <span class="p">[</span><span class="s2">&quot;&lt;-&gt;&quot;</span><span class="p">,</span> <span class="s2">&quot;&lt;-+&quot;</span><span class="p">,</span> <span class="s2">&quot;+-&gt;&quot;</span><span class="p">]</span>
        <span class="k">elif</span> <span class="s1">&#39;mag&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span><span class="p">:</span>
            <span class="n">allowed_edges</span> <span class="o">+=</span> <span class="p">[</span><span class="s2">&quot;&lt;-&gt;&quot;</span><span class="p">]</span>
        <span class="k">elif</span> <span class="s1">&#39;pag&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span><span class="p">:</span>
            <span class="n">allowed_edges</span> <span class="o">+=</span> <span class="p">[</span><span class="s2">&quot;&lt;-&gt;&quot;</span><span class="p">,</span> <span class="s2">&quot;o-o&quot;</span><span class="p">,</span> <span class="s2">&quot;o-&gt;&quot;</span><span class="p">,</span> <span class="s2">&quot;&lt;-o&quot;</span><span class="p">]</span>
            <span class="c1"># Further edge types (commented out): &quot;o--&quot;, &quot;--o&quot;, &quot;x-o&quot;, &quot;o-x&quot;, &quot;x--&quot;, &quot;--x&quot;, &quot;x-&gt;&quot;, &quot;&lt;-x&quot;, &quot;x-x&quot;</span>

        <span class="n">graph_dict</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="p">)):</span>
            <span class="n">edge</span> <span class="o">=</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span><span class="p">]</span>
            <span class="c1"># print((i, -taui), edge, (j, -tauj), graph[j, i, tauj, taui])</span>
            <span class="k">if</span> <span class="n">edge</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reverse_link</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">tauj</span><span class="p">,</span> <span class="n">taui</span><span class="p">]):</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
                    <span class="s2">&quot;graph needs to have consistent edges (eg&quot;</span>
                    <span class="s2">&quot; graph[i,j,taui,tauj]=&#39;--&gt;&#39; requires graph[j,i,tauj,taui]=&#39;&lt;--&#39;)&quot;</span>
                <span class="p">)</span>
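            <span class="c1"># Illustrative example: the check above pairs each edge with its mirror</span>
            <span class="c1"># entry, e.g. graph[i, j, taui, tauj] = &#39;--&gt;&#39; with graph[j, i, tauj, taui] = &#39;&lt;--&#39;.</span>
            <span class="c1"># Assuming _reverse_link flips the edge and swaps arrowheads, as the error</span>
            <span class="c1"># message suggests, &#39;&lt;-&gt;&#39; pairs with &#39;&lt;-&gt;&#39; and &#39;o-&gt;&#39; pairs with &#39;&lt;-o&#39;.</span>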

            <span class="k">if</span> <span class="n">edge</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">allowed_edges</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid graph edge </span><span class="si">%s</span><span class="s2">. &quot;</span> <span class="o">%</span><span class="p">(</span><span class="n">edge</span><span class="p">)</span> <span class="o">+</span>
                                 <span class="s2">&quot;For graph_type = </span><span class="si">%s</span><span class="s2"> only </span><span class="si">%s</span><span class="s2"> are allowed.&quot;</span> <span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">allowed_edges</span><span class="p">)))</span>

            <span class="k">if</span> <span class="n">edge</span> <span class="o">==</span> <span class="s2">&quot;--&gt;&quot;</span> <span class="ow">or</span> <span class="n">edge</span> <span class="o">==</span> <span class="s2">&quot;+-&gt;&quot;</span><span class="p">:</span>
                <span class="c1"># Map nodes (i, -taui) and (j, -tauj) to flat integer indices</span>
                <span class="n">indexi</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">taui</span>
                <span class="n">indexj</span> <span class="o">=</span> <span class="n">j</span> <span class="o">*</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">tauj</span>

                <span class="n">graph_dict</span><span class="p">[</span><span class="n">indexj</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indexi</span><span class="p">)</span>
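                <span class="c1"># Illustrative example (assumed values): with tau_max = 2 each variable</span>
                <span class="c1"># occupies 3 consecutive indices, so node (1, -2) maps to 1 * 3 + 2 = 5</span>
                <span class="c1"># and node (0, 0) maps to 0 * 3 + 0 = 0. An edge (1, -2) --&gt; (0, 0) is</span>
                <span class="c1"># then stored as graph_dict[0] containing 5, i.e. child index to parent indices.</span>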

        <span class="c1"># Check for cycles</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_cyclic</span><span class="p">(</span><span class="n">graph_dict</span><span class="p">):</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;graph is cyclic.&quot;</span><span class="p">)</span>

        <span class="c1"># TODO: if MAG, check for almost cycles</span>
        <span class="c1"># TODO: decide what to check if PAG</span>

    <span class="k">def</span> <span class="nf">_check_cyclic</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph_dict</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Return True if the graph_dict has a cycle.</span>

<span class="sd">        graph_dict must be represented as a dictionary mapping vertices to</span>
<span class="sd">        iterables of neighbouring vertices. For example:</span>

<span class="sd">        &gt;&gt;&gt; self._check_cyclic({1: (2,), 2: (3,), 3: (1,)})</span>
<span class="sd">        True</span>
<span class="sd">        &gt;&gt;&gt; self._check_cyclic({1: (2,), 2: (3,), 3: (4,)})</span>
<span class="sd">        False</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">path</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
        <span class="n">visited</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>

        <span class="k">def</span> <span class="nf">visit</span><span class="p">(</span><span class="n">vertex</span><span class="p">):</span>
            <span class="k">if</span> <span class="n">vertex</span> <span class="ow">in</span> <span class="n">visited</span><span class="p">:</span>
                <span class="k">return</span> <span class="kc">False</span>
            <span class="n">visited</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">vertex</span><span class="p">)</span>
            <span class="n">path</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">vertex</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">neighbour</span> <span class="ow">in</span> <span class="n">graph_dict</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">vertex</span><span class="p">,</span> <span class="p">()):</span>
                <span class="k">if</span> <span class="n">neighbour</span> <span class="ow">in</span> <span class="n">path</span> <span class="ow">or</span> <span class="n">visit</span><span class="p">(</span><span class="n">neighbour</span><span class="p">):</span>
                    <span class="k">return</span> <span class="kc">True</span>
            <span class="n">path</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">vertex</span><span class="p">)</span>
            <span class="k">return</span> <span class="kc">False</span>

        <span class="k">return</span> <span class="nb">any</span><span class="p">(</span><span class="n">visit</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">graph_dict</span><span class="p">)</span>
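        <span class="c1"># How the check works: visit() is a depth-first search in which `path`</span>
        <span class="c1"># holds the vertices on the current recursion stack and `visited` holds all</span>
        <span class="c1"># vertices already explored. Reaching a vertex that is still in `path`</span>
        <span class="c1"># means a back edge, i.e. a cycle. For {1: (2,), 2: (3,), 3: (1,)} the</span>
        <span class="c1"># search visits 1, 2 and 3 and then finds 1 in path, so the result is True.</span>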

<div class="viewcode-block" id="CausalEffects.get_mediators"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.get_mediators">[docs]</a>    <span class="k">def</span> <span class="nf">get_mediators</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns mediator variables on proper causal paths.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        start : set</span>
<span class="sd">            Set of start nodes.</span>
<span class="sd">        end : set</span>
<span class="sd">            Set of end nodes.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        mediators : set</span>
<span class="sd">            Mediators on causal paths from start to end.</span>
<span class="sd">        &quot;&quot;&quot;</span>
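
        <span class="c1"># Illustrative example (hypothetical graph; nodes are (variable, lag) tuples, lag &lt;= 0):</span>
        <span class="c1"># for the chain (0, -2) --&gt; (1, -1) --&gt; (2, 0) with tau_max &gt;= 2,</span>
        <span class="c1"># get_mediators(start={(0, -2)}, end={(2, 0)}) is expected to return</span>
        <span class="c1"># {(1, -1)}: walking backwards from (2, 0), its parent (1, -1) is a</span>
        <span class="c1"># descendant of (0, -2) and lies in neither start nor end.</span>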

        <span class="n">des_X</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_descendants</span><span class="p">(</span><span class="n">start</span><span class="p">)</span>

        <span class="n">mediators</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>

        <span class="c1"># Walk along proper causal paths backwards from Y to X</span>
        <span class="c1"># potential_mediators = set()</span>
        <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">end</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">y</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">y</span><span class="p">]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>
                    <span class="k">for</span> <span class="n">parent</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_parents</span><span class="p">(</span><span class="n">varlag</span><span class="p">):</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">parent</span>
                        <span class="c1"># print(varlag, parent, des_X)</span>
                        <span class="k">if</span> <span class="p">(</span><span class="n">parent</span> <span class="ow">in</span> <span class="n">des_X</span>
                            <span class="ow">and</span> <span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">mediators</span>
                            <span class="c1"># and parent not in potential_mediators</span>
                            <span class="ow">and</span> <span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">start</span>
                            <span class="ow">and</span> <span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">end</span>
                            <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)):</span> <span class="c1"># or self.ignore_time_bounds)):</span>
                            <span class="n">mediators</span> <span class="o">=</span> <span class="n">mediators</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">parent</span><span class="p">]))</span>
                            <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">parent</span><span class="p">)</span>
                            
                <span class="n">this_level</span> <span class="o">=</span> <span class="n">next_level</span>  

        <span class="k">return</span> <span class="n">mediators</span></div>

    <span class="k">def</span> <span class="nf">_get_mediators_stationary_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span> <span class="n">max_lag</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns mediator variables on proper causal paths</span>
<span class="sd">           from start to end in a stationary graph.&quot;&quot;&quot;</span>

        <span class="n">des_X</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_descendants_stationary_graph</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">max_lag</span><span class="p">)</span>

        <span class="n">mediators</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>

        <span class="c1"># Walk along proper causal paths backwards from Y to X</span>
        <span class="n">potential_mediators</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">end</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">y</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">y</span><span class="p">]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>
                    <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">parent</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_adjacents_stationary_graph</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">graph</span><span class="p">,</span> 
                                <span class="n">node</span><span class="o">=</span><span class="n">varlag</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;&lt;*-&quot;</span><span class="p">,</span> <span class="s2">&quot;&lt;*+&quot;</span><span class="p">],</span> <span class="n">max_lag</span><span class="o">=</span><span class="n">max_lag</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">parent</span>
                        <span class="k">if</span> <span class="p">(</span><span class="n">parent</span> <span class="ow">in</span> <span class="n">des_X</span>
                            <span class="ow">and</span> <span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">mediators</span>
                            <span class="c1"># and parent not in potential_mediators</span>
                            <span class="ow">and</span> <span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">start</span>
                            <span class="ow">and</span> <span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">end</span>
                            <span class="c1"># and (-self.tau_max &lt;= tau &lt;= 0 or self.ignore_time_bounds)</span>
                            <span class="p">):</span>
                            <span class="n">mediators</span> <span class="o">=</span> <span class="n">mediators</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">parent</span><span class="p">]))</span>
                            <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">parent</span><span class="p">)</span>
                            
                <span class="n">this_level</span> <span class="o">=</span> <span class="n">next_level</span>  

        <span class="k">return</span> <span class="n">mediators</span>

    <span class="k">def</span> <span class="nf">_reverse_link</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">link</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Reverse a given link, taking care to replace &gt; with &lt; and vice versa.&quot;&quot;&quot;</span>

        <span class="k">if</span> <span class="n">link</span> <span class="o">==</span> <span class="s2">&quot;&quot;</span><span class="p">:</span>
            <span class="k">return</span> <span class="s2">&quot;&quot;</span>

        <span class="k">if</span> <span class="n">link</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;&gt;&quot;</span><span class="p">:</span>
            <span class="n">left_mark</span> <span class="o">=</span> <span class="s2">&quot;&lt;&quot;</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">left_mark</span> <span class="o">=</span> <span class="n">link</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>

        <span class="k">if</span> <span class="n">link</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;&lt;&quot;</span><span class="p">:</span>
            <span class="n">right_mark</span> <span class="o">=</span> <span class="s2">&quot;&gt;&quot;</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">right_mark</span> <span class="o">=</span> <span class="n">link</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

        <span class="k">return</span> <span class="n">left_mark</span> <span class="o">+</span> <span class="n">link</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">right_mark</span>
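        <span class="c1"># Examples (traced by hand from the code above):</span>
        <span class="c1">#   _reverse_link(&quot;--&gt;&quot;) returns &quot;&lt;--&quot;</span>
        <span class="c1">#   _reverse_link(&quot;o-&gt;&quot;) returns &quot;&lt;-o&quot;</span>
        <span class="c1">#   _reverse_link(&quot;&lt;*+&quot;) returns &quot;+*&gt;&quot;</span>
        <span class="c1"># _find_adj relies on this to apply a pattern written from the node&#39;s</span>
        <span class="c1"># side to links stored in the opposite direction.</span>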

    <span class="k">def</span> <span class="nf">_match_link</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pattern</span><span class="p">,</span> <span class="n">link</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check whether link matches pattern, where &#39;*&#39; in pattern is a wildcard.</span>
<span class="sd">           </span>
<span class="sd">           In an ADMG we have edge types [&quot;--&gt;&quot;, &quot;&lt;--&quot;, &quot;&lt;-&gt;&quot;, &quot;+-&gt;&quot;, &quot;&lt;-+&quot;].</span>
<span class="sd">           Here +-&gt; corresponds to having both &quot;--&gt;&quot; and &quot;&lt;-&gt;&quot;.</span>

<span class="sd">           In a MAG we have edge types   [&quot;--&gt;&quot;, &quot;&lt;--&quot;, &quot;&lt;-&gt;&quot;, &quot;---&quot;].</span>
<span class="sd">        &quot;&quot;&quot;</span>
        
        <span class="k">if</span> <span class="n">pattern</span> <span class="o">==</span> <span class="s1">&#39;&#39;</span> <span class="ow">or</span> <span class="n">link</span> <span class="o">==</span> <span class="s1">&#39;&#39;</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">pattern</span> <span class="o">==</span> <span class="n">link</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">left_mark</span><span class="p">,</span> <span class="n">middle_mark</span><span class="p">,</span> <span class="n">right_mark</span> <span class="o">=</span> <span class="n">pattern</span>
            <span class="k">if</span> <span class="n">left_mark</span> <span class="o">!=</span> <span class="s1">&#39;*&#39;</span><span class="p">:</span>
                <span class="c1"># if link[0] != &#39;+&#39;:</span>
                <span class="k">if</span> <span class="n">link</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">left_mark</span><span class="p">:</span> <span class="k">return</span> <span class="kc">False</span>

            <span class="k">if</span> <span class="n">right_mark</span> <span class="o">!=</span> <span class="s1">&#39;*&#39;</span><span class="p">:</span>
                <span class="c1"># if link[2] != &#39;+&#39;:</span>
                <span class="k">if</span> <span class="n">link</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">!=</span> <span class="n">right_mark</span><span class="p">:</span> <span class="k">return</span> <span class="kc">False</span>
            
            <span class="k">if</span> <span class="n">middle_mark</span> <span class="o">!=</span> <span class="s1">&#39;*&#39;</span> <span class="ow">and</span> <span class="n">link</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">middle_mark</span><span class="p">:</span> <span class="k">return</span> <span class="kc">False</span>    
                       
            <span class="k">return</span> <span class="kc">True</span>
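        <span class="c1"># Examples (traced by hand from the code above):</span>
        <span class="c1">#   _match_link(&quot;-*&gt;&quot;, &quot;--&gt;&quot;) returns True   (&quot;*&quot; accepts any middle mark)</span>
        <span class="c1">#   _match_link(&quot;-*&gt;&quot;, &quot;&lt;-&gt;&quot;) returns False  (left marks differ)</span>
        <span class="c1">#   _match_link(&quot;&lt;*+&quot;, &quot;&lt;-+&quot;) returns True</span>
        <span class="c1">#   _match_link(&quot;&quot;, &quot;&quot;) returns True, _match_link(&quot;-*&gt;&quot;, &quot;&quot;) returns False</span>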

    <span class="k">def</span> <span class="nf">_find_adj</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">node</span><span class="p">,</span> <span class="n">patterns</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">return_link</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Find adjacencies of node that match given patterns.&quot;&quot;&quot;</span>
        
        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">graph</span>

        <span class="k">if</span> <span class="n">exclude</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">exclude</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="c1">#     exclude = self.hidden_variables</span>
        <span class="c1"># else:</span>
        <span class="c1">#     exclude = set(exclude).union(self.hidden_variables)</span>

        <span class="c1"># Setup</span>
        <span class="n">i</span><span class="p">,</span> <span class="n">lag_i</span> <span class="o">=</span> <span class="n">node</span>
        <span class="n">lag_i</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag_i</span><span class="p">)</span>

        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">patterns</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
            <span class="n">patterns</span> <span class="o">=</span> <span class="p">[</span><span class="n">patterns</span><span class="p">]</span>

        <span class="c1"># Init</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="c1"># Find adjacencies going forward/contemp</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">lag_ik</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,:,</span><span class="n">lag_i</span><span class="p">,:])):</span>
            <span class="c1"># print((k, lag_ik), graph[i,k,lag_i,lag_ik]) </span>
            <span class="c1"># matches = [self._match_link(patt, graph[i,k,lag_i,lag_ik]) for patt in patterns]</span>
            <span class="c1"># if np.any(matches):</span>
            <span class="k">for</span> <span class="n">patt</span> <span class="ow">in</span> <span class="n">patterns</span><span class="p">:</span>
                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="n">patt</span><span class="p">,</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">k</span><span class="p">,</span><span class="n">lag_i</span><span class="p">,</span><span class="n">lag_ik</span><span class="p">]):</span>
                    <span class="n">match</span> <span class="o">=</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="o">-</span><span class="n">lag_ik</span><span class="p">)</span>
                    <span class="k">if</span> <span class="n">match</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">exclude</span><span class="p">:</span>
                        <span class="k">if</span> <span class="n">return_link</span><span class="p">:</span>
                            <span class="n">adj</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">k</span><span class="p">,</span><span class="n">lag_i</span><span class="p">,</span><span class="n">lag_ik</span><span class="p">],</span> <span class="n">match</span><span class="p">))</span>
                        <span class="k">else</span><span class="p">:</span>
                            <span class="n">adj</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">match</span><span class="p">)</span>
                    <span class="k">break</span>

        
        <span class="c1"># Find adjacencies going backward/contemp</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">lag_ki</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="p">[:,</span><span class="n">i</span><span class="p">,:,</span><span class="n">lag_i</span><span class="p">])):</span>  
            <span class="c1"># print((k, lag_ki), graph[k,i,lag_ki,lag_i]) </span>
            <span class="c1"># matches = [self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]) for patt in patterns]</span>
            <span class="c1"># if np.any(matches):</span>
            <span class="k">for</span> <span class="n">patt</span> <span class="ow">in</span> <span class="n">patterns</span><span class="p">:</span>
                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_reverse_link</span><span class="p">(</span><span class="n">patt</span><span class="p">),</span> <span class="n">graph</span><span class="p">[</span><span class="n">k</span><span class="p">,</span><span class="n">i</span><span class="p">,</span><span class="n">lag_ki</span><span class="p">,</span><span class="n">lag_i</span><span class="p">]):</span>
                    <span class="n">match</span> <span class="o">=</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="o">-</span><span class="n">lag_ki</span><span class="p">)</span>
                    <span class="k">if</span> <span class="n">match</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">exclude</span><span class="p">:</span>
                        <span class="k">if</span> <span class="n">return_link</span><span class="p">:</span>
                            <span class="n">adj</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_reverse_link</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="n">k</span><span class="p">,</span><span class="n">i</span><span class="p">,</span><span class="n">lag_ki</span><span class="p">,</span><span class="n">lag_i</span><span class="p">]),</span> <span class="n">match</span><span class="p">))</span>
                        <span class="k">else</span><span class="p">:</span>
                            <span class="n">adj</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">match</span><span class="p">)</span>
                    <span class="k">break</span>
     
        <span class="n">adj</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">adj</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">adj</span>
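        <span class="c1"># Index convention, as read off the two loops above: graph[i, k, lag_i, lag_k]</span>
        <span class="c1"># holds the mark of the link from node (i, -lag_i) to node (k, -lag_k).</span>
        <span class="c1"># Example with a hypothetical 2-variable graph (tau_max = 1):</span>
        <span class="c1">#   self.graph = np.zeros((2, 2, 2, 2), dtype=&quot;&lt;U3&quot;)</span>
        <span class="c1">#   self.graph[0, 1, 1, 0] = &quot;--&gt;&quot;   # (0, -1) --&gt; (1, 0)</span>
        <span class="c1">#   _find_adj(node=(1, 0), patterns=[&quot;&lt;*-&quot;]) finds (0, -1) via the</span>
        <span class="c1">#   backward loop, matching the reversed pattern &quot;-*&gt;&quot; against &quot;--&gt;&quot;;</span>
        <span class="c1">#   _find_adj(node=(0, -1), patterns=[&quot;-*&gt;&quot;]) finds (1, 0) via the forward loop.</span>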

    <span class="k">def</span> <span class="nf">_is_match</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nodei</span><span class="p">,</span> <span class="n">nodej</span><span class="p">,</span> <span class="n">pattern_ij</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check whether the link between nodei and nodej agrees with pattern_ij.&quot;&quot;&quot;</span>

        <span class="n">graph</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">graph</span>

        <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">lag_i</span><span class="p">)</span> <span class="o">=</span> <span class="n">nodei</span>
        <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="n">lag_j</span><span class="p">)</span> <span class="o">=</span> <span class="n">nodej</span>
        <span class="n">tauij</span> <span class="o">=</span> <span class="n">lag_j</span> <span class="o">-</span> <span class="n">lag_i</span>
        <span class="k">if</span> <span class="nb">abs</span><span class="p">(</span><span class="n">tauij</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="n">graph</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]:</span>
            <span class="k">return</span> <span class="kc">False</span>
        <span class="k">return</span> <span class="p">((</span><span class="n">tauij</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="n">pattern_ij</span><span class="p">,</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">tauij</span><span class="p">]))</span> <span class="ow">or</span>
               <span class="p">(</span><span class="n">tauij</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_reverse_link</span><span class="p">(</span><span class="n">pattern_ij</span><span class="p">),</span> <span class="n">graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">tauij</span><span class="p">)])))</span>

    <span class="k">def</span> <span class="nf">_get_children</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">varlag</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns set of children (varlag --&gt; ...) for (lagged) varlag.&quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible</span><span class="p">:</span>
            <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;o*o&#39;</span><span class="p">,</span> <span class="s1">&#39;o*&gt;&#39;</span><span class="p">]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;+*&gt;&#39;</span><span class="p">]</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_adj</span><span class="p">(</span><span class="n">node</span><span class="o">=</span><span class="n">varlag</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="n">patterns</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_get_parents</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">varlag</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns set of parents (varlag &lt;-- ...) for (lagged) varlag.&quot;&quot;&quot;</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">possible</span><span class="p">:</span>
            <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;&lt;*-&#39;</span><span class="p">,</span> <span class="s1">&#39;o*o&#39;</span><span class="p">,</span> <span class="s1">&#39;&lt;*o&#39;</span><span class="p">]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;&lt;*-&#39;</span><span class="p">,</span> <span class="s1">&#39;&lt;*+&#39;</span><span class="p">]</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_adj</span><span class="p">(</span><span class="n">node</span><span class="o">=</span><span class="n">varlag</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="n">patterns</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_get_spouses</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">varlag</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns set of spouses (varlag &lt;-&gt; ...) for (lagged) varlag.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_adj</span><span class="p">(</span><span class="n">node</span><span class="o">=</span><span class="n">varlag</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;&lt;*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;+*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;&lt;*+&#39;</span><span class="p">])</span>

    <span class="k">def</span> <span class="nf">_get_neighbors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">varlag</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns set of neighbors (varlag --- ...) for (lagged) varlag.&quot;&quot;&quot;</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_adj</span><span class="p">(</span><span class="n">node</span><span class="o">=</span><span class="n">varlag</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*-&#39;</span><span class="p">])</span>

    <span class="k">def</span> <span class="nf">_get_ancestors</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">W</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Get ancestors of nodes in W with lags between -tau_max and 0.</span>
<span class="sd">        </span>
<span class="sd">        Includes the nodes themselves.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">ancestors</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">W</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">w</span><span class="p">]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>

                    <span class="k">for</span> <span class="n">par</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_parents</span><span class="p">(</span><span class="n">varlag</span><span class="p">):</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">par</span>
                        <span class="k">if</span> <span class="n">par</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">ancestors</span> <span class="ow">and</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
                            <span class="n">ancestors</span> <span class="o">=</span> <span class="n">ancestors</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">par</span><span class="p">]))</span>
                            <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">par</span><span class="p">)</span>

                <span class="n">this_level</span> <span class="o">=</span> <span class="n">next_level</span>       

        <span class="k">return</span> <span class="n">ancestors</span>
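
    <span class="c1"># Sketch (illustrative only, not part of tigramite; ancestors_sketch and</span>
    <span class="c1"># parents_of are hypothetical names): the level-by-level loop above is a</span>
    <span class="c1"># breadth-first search over parents that keeps only nodes whose lag lies</span>
    <span class="c1"># in [-tau_max, 0]. This sketch assumes parents are given as a plain dict</span>
    <span class="c1"># mapping (var, lag) nodes to lists of (var, lag) parents at absolute lags.</span>
    <span class="c1">#</span>
    <span class="c1">#     def ancestors_sketch(parents_of, W, tau_max):</span>
    <span class="c1">#         ancestors = set(W)  # W itself is included</span>
    <span class="c1">#         this_level = list(W)</span>
    <span class="c1">#         while this_level:</span>
    <span class="c1">#             next_level = []</span>
    <span class="c1">#             for node in this_level:</span>
    <span class="c1">#                 for par in parents_of.get(node, []):</span>
    <span class="c1">#                     if par not in ancestors and -tau_max &lt;= par[1] &lt;= 0:</span>
    <span class="c1">#                         ancestors.add(par)</span>
    <span class="c1">#                         next_level.append(par)</span>
    <span class="c1">#             this_level = next_level</span>
    <span class="c1">#         return ancestors</span>
    <span class="c1">#</span>
    <span class="c1">#     parents_of = {(0, 0): [(1, -1)], (1, -1): [(2, -3)]}</span>
    <span class="c1">#     ancestors_sketch(parents_of, [(0, 0)], tau_max=2)</span>
    <span class="c1">#     # -&gt; {(0, 0), (1, -1)}, since (2, -3) lies outside [-2, 0]</span>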

    <span class="k">def</span> <span class="nf">_get_all_parents</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">W</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Get the direct parents of nodes in W whose lags lie in [-tau_max, 0].</span>
<span class="sd">        </span>
<span class="sd">        Includes the nodes themselves.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">parents</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">W</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="k">for</span> <span class="n">par</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_parents</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
                <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">par</span>
                <span class="k">if</span> <span class="n">par</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">parents</span> <span class="ow">and</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
                    <span class="n">parents</span> <span class="o">=</span> <span class="n">parents</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">par</span><span class="p">]))</span>

        <span class="k">return</span> <span class="n">parents</span>

    <span class="k">def</span> <span class="nf">_get_all_spouses</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">W</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Get the spouses of nodes in W whose lags lie in [-tau_max, 0].</span>
<span class="sd">        </span>
<span class="sd">        Includes the nodes themselves.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">spouses</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">W</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="k">for</span> <span class="n">spouse</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_spouses</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
                <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">spouse</span>
                <span class="k">if</span> <span class="n">spouse</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">spouses</span> <span class="ow">and</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
                    <span class="n">spouses</span> <span class="o">=</span> <span class="n">spouses</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">spouse</span><span class="p">]))</span>

        <span class="k">return</span> <span class="n">spouses</span>

    <span class="k">def</span> <span class="nf">_get_descendants_stationary_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">max_lag</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Get descendants of nodes in W in the stationary graph, restricted to lags in [-max_lag, 0].</span>
<span class="sd">        </span>
<span class="sd">        Includes the nodes themselves.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">descendants</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">W</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">w</span><span class="p">]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>
                    <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">child</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_adjacents_stationary_graph</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">graph</span><span class="p">,</span> 
                                <span class="n">node</span><span class="o">=</span><span class="n">varlag</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;-*&gt;&quot;</span><span class="p">,</span> <span class="s2">&quot;-*+&quot;</span><span class="p">],</span> <span class="n">max_lag</span><span class="o">=</span><span class="n">max_lag</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">child</span>
                        <span class="k">if</span> <span class="p">(</span><span class="n">child</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">descendants</span> 
                            <span class="c1"># and (-self.tau_max &lt;= tau &lt;= 0 or self.ignore_time_bounds)</span>
                            <span class="p">):</span>
                            <span class="n">descendants</span> <span class="o">=</span> <span class="n">descendants</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">child</span><span class="p">]))</span>
                            <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">child</span><span class="p">)</span>

                <span class="n">this_level</span> <span class="o">=</span> <span class="n">next_level</span>       

        <span class="k">return</span> <span class="n">descendants</span>

    <span class="k">def</span> <span class="nf">_get_descendants</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">W</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Get descendants of nodes in W whose lags lie in [-tau_max, 0].</span>
<span class="sd">        </span>
<span class="sd">        Includes the nodes themselves.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">descendants</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">W</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">w</span><span class="p">]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>
                    <span class="k">for</span> <span class="n">child</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_children</span><span class="p">(</span><span class="n">varlag</span><span class="p">):</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">child</span>
                        <span class="k">if</span> <span class="p">(</span><span class="n">child</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">descendants</span> 
                            <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)):</span> <span class="c1"># or self.ignore_time_bounds)):</span>
                            <span class="n">descendants</span> <span class="o">=</span> <span class="n">descendants</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">child</span><span class="p">]))</span>
                            <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">child</span><span class="p">)</span>

                <span class="n">this_level</span> <span class="o">=</span> <span class="n">next_level</span>       

        <span class="k">return</span> <span class="n">descendants</span>

    <span class="k">def</span> <span class="nf">_get_collider_path_nodes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">descendants</span><span class="p">):</span>
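<span class="w">        </span><span class="sd">&quot;&quot;&quot;Get non-descendant collider path nodes of W and their parents.</span>
<span class="sd">        </span>
<span class="sd">        Collider path nodes are reached from W via chains of spouses, must not</span>
<span class="sd">        be in descendants, and have lags in [-tau_max, 0].</span>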
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">collider_path_nodes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([])</span>
        <span class="c1"># print(&quot;descendants &quot;, descendants)</span>
        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">:</span>
            <span class="c1"># print(w)</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">w</span><span class="p">]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>
                    <span class="c1"># print(&quot;\t&quot;, varlag, self._get_spouses(varlag))</span>
                    <span class="k">for</span> <span class="n">spouse</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_spouses</span><span class="p">(</span><span class="n">varlag</span><span class="p">):</span>
                        <span class="c1"># print(&quot;\t\t&quot;, spouse)</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">spouse</span>
                        <span class="k">if</span> <span class="p">(</span><span class="n">spouse</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">collider_path_nodes</span>
                            <span class="ow">and</span> <span class="n">spouse</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">descendants</span> 
                            <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)):</span> <span class="c1"># or self.ignore_time_bounds)):</span>
                            <span class="n">collider_path_nodes</span> <span class="o">=</span> <span class="n">collider_path_nodes</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">spouse</span><span class="p">]))</span>
                            <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">spouse</span><span class="p">)</span>

                <span class="n">this_level</span> <span class="o">=</span> <span class="n">next_level</span>       

        <span class="c1"># Add parents</span>
        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">collider_path_nodes</span><span class="p">:</span>
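            <span class="k">for</span> <span class="n">par</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_parents</span><span class="p">(</span><span class="n">w</span><span class="p">):</span>
                <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">par</span>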
                <span class="k">if</span> <span class="p">(</span><span class="n">par</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">collider_path_nodes</span>
                    <span class="ow">and</span> <span class="n">par</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">descendants</span>
                    <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)):</span> <span class="c1"># or self.ignore_time_bounds)):</span>
                    <span class="n">collider_path_nodes</span> <span class="o">=</span> <span class="n">collider_path_nodes</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">par</span><span class="p">]))</span>

        <span class="k">return</span> <span class="n">collider_path_nodes</span>
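
    <span class="c1"># Sketch (illustrative only, not part of tigramite; all names below are</span>
    <span class="c1"># hypothetical): starting from W, the search above follows chains of</span>
    <span class="c1"># spouses, skipping nodes that are already collected, lie in descendants,</span>
    <span class="c1"># or fall outside [-tau_max, 0]; parents of the collected nodes are then</span>
    <span class="c1"># added. This sketch checks each parent&#39;s own lag against the window.</span>
    <span class="c1">#</span>
    <span class="c1">#     def collider_path_nodes_sketch(spouses_of, parents_of, W,</span>
    <span class="c1">#                                    descendants, tau_max):</span>
    <span class="c1">#         def admissible(node, found):</span>
    <span class="c1">#             return (node not in found and node not in descendants</span>
    <span class="c1">#                     and -tau_max &lt;= node[1] &lt;= 0)</span>
    <span class="c1">#</span>
    <span class="c1">#         found = set()</span>
    <span class="c1">#         this_level = list(W)</span>
    <span class="c1">#         while this_level:</span>
    <span class="c1">#             next_level = []</span>
    <span class="c1">#             for node in this_level:</span>
    <span class="c1">#                 for sp in spouses_of.get(node, []):</span>
    <span class="c1">#                     if admissible(sp, found):</span>
    <span class="c1">#                         found.add(sp)</span>
    <span class="c1">#                         next_level.append(sp)</span>
    <span class="c1">#             this_level = next_level</span>
    <span class="c1">#         for node in list(found):  # parents of spouse-chain nodes only</span>
    <span class="c1">#             for par in parents_of.get(node, []):</span>
    <span class="c1">#                 if admissible(par, found):</span>
    <span class="c1">#                     found.add(par)</span>
    <span class="c1">#         return found</span>
    <span class="c1">#</span>
    <span class="c1">#     spouses_of = {(0, 0): [(1, 0)], (1, 0): [(0, 0), (2, -1)]}</span>
    <span class="c1">#     parents_of = {(1, 0): [(3, -1)], (2, -1): [(3, -5)]}</span>
    <span class="c1">#     collider_path_nodes_sketch(spouses_of, parents_of, [(0, 0)],</span>
    <span class="c1">#                                descendants={(0, 0)}, tau_max=2)</span>
    <span class="c1">#     # -&gt; {(1, 0), (2, -1), (3, -1)}</span>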

    <span class="k">def</span> <span class="nf">_get_adjacents_stationary_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">,</span> <span class="n">node</span><span class="p">,</span> <span class="n">patterns</span><span class="p">,</span> 
        <span class="n">max_lag</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Find adjacencies of node matching patterns in a stationary graph.&quot;&quot;&quot;</span>
        
        <span class="c1"># graph = self.graph</span>

        <span class="c1"># Setup</span>
        <span class="n">i</span><span class="p">,</span> <span class="n">lag_i</span> <span class="o">=</span> <span class="n">node</span>
        <span class="k">if</span> <span class="n">exclude</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">exclude</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">patterns</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
            <span class="n">patterns</span> <span class="o">=</span> <span class="p">[</span><span class="n">patterns</span><span class="p">]</span>

        <span class="c1"># Init</span>
        <span class="n">adj</span> <span class="o">=</span> <span class="p">[]</span>

        <span class="c1"># Find adjacencies going forward/contemp</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">lag_ik</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,:,:])):</span>  
            <span class="n">matches</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="n">patt</span><span class="p">,</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">lag_ik</span><span class="p">])</span> <span class="k">for</span> <span class="n">patt</span> <span class="ow">in</span> <span class="n">patterns</span><span class="p">]</span>
            <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">matches</span><span class="p">):</span>
                <span class="n">match</span> <span class="o">=</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">lag_i</span> <span class="o">+</span> <span class="n">lag_ik</span><span class="p">)</span>
                <span class="k">if</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">lag_i</span> <span class="o">+</span> <span class="n">lag_ik</span><span class="p">)</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">exclude</span> <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="n">max_lag</span> <span class="o">&lt;=</span> <span class="n">lag_i</span> <span class="o">+</span> <span class="n">lag_ik</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">):</span> <span class="c1"># or self.ignore_time_bounds):</span>
                    <span class="n">adj</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">lag_ik</span><span class="p">],</span> <span class="n">match</span><span class="p">))</span>
        
        <span class="c1"># Find adjacencies going backward/contemp</span>
        <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">lag_ki</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="p">[:,</span><span class="n">i</span><span class="p">,:])):</span>  
            <span class="n">matches</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_reverse_link</span><span class="p">(</span><span class="n">patt</span><span class="p">),</span> <span class="n">graph</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">lag_ki</span><span class="p">])</span> <span class="k">for</span> <span class="n">patt</span> <span class="ow">in</span> <span class="n">patterns</span><span class="p">]</span>
            <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">matches</span><span class="p">):</span>
                <span class="n">match</span> <span class="o">=</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">lag_i</span> <span class="o">-</span> <span class="n">lag_ki</span><span class="p">)</span>
                <span class="k">if</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">lag_i</span> <span class="o">-</span> <span class="n">lag_ki</span><span class="p">)</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">exclude</span> <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="n">max_lag</span> <span class="o">&lt;=</span> <span class="n">lag_i</span> <span class="o">-</span> <span class="n">lag_ki</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">):</span> <span class="c1"># or self.ignore_time_bounds):</span>
                    <span class="n">adj</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_reverse_link</span><span class="p">(</span><span class="n">graph</span><span class="p">[</span><span class="n">k</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">lag_ki</span><span class="p">]),</span> <span class="n">match</span><span class="p">))</span>         
        
        <span class="n">adj</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">adj</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">adj</span>
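
    <span class="c1"># Sketch of the index arithmetic above (illustrative only; the function</span>
    <span class="c1"># name is hypothetical). An entry graph[i, k, tau] describes the link</span>
    <span class="c1"># between variable i at time t - tau and variable k at time t. For a node</span>
    <span class="c1"># (i, lag_i), an entry graph[i, k, lag_ik] therefore places the neighbor</span>
    <span class="c1"># at (k, lag_i + lag_ik), and an entry graph[k, i, lag_ki] places it at</span>
    <span class="c1"># (k, lag_i - lag_ki); neighbors outside [-max_lag, 0] are dropped. The</span>
    <span class="c1"># pattern matching on link types is omitted here.</span>
    <span class="c1">#</span>
    <span class="c1">#     import numpy as np</span>
    <span class="c1">#</span>
    <span class="c1">#     def neighbor_positions_sketch(graph, node, max_lag):</span>
    <span class="c1">#         i, lag_i = node</span>
    <span class="c1">#         found = set()</span>
    <span class="c1">#         for k, lag_ik in zip(*np.where(graph[i, :, :] != &#39;&#39;)):</span>
    <span class="c1">#             if -max_lag &lt;= lag_i + lag_ik &lt;= 0:</span>
    <span class="c1">#                 found.add((int(k), int(lag_i + lag_ik)))</span>
    <span class="c1">#         for k, lag_ki in zip(*np.where(graph[:, i, :] != &#39;&#39;)):</span>
    <span class="c1">#             if -max_lag &lt;= lag_i - lag_ki &lt;= 0:</span>
    <span class="c1">#                 found.add((int(k), int(lag_i - lag_ki)))</span>
    <span class="c1">#         return found</span>
    <span class="c1">#</span>
    <span class="c1">#     graph = np.zeros((2, 2, 2), dtype=&#39;&lt;U3&#39;)  # empty string: no link</span>
    <span class="c1">#     graph[0, 1, 1] = &#39;--&gt;&#39;                     # var 0 at t-1 --&gt; var 1 at t</span>
    <span class="c1">#     neighbor_positions_sketch(graph, (1, 0), max_lag=1)   # -&gt; {(0, -1)}</span>
    <span class="c1">#     neighbor_positions_sketch(graph, (0, -1), max_lag=1)  # -&gt; {(1, 0)}</span>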

    <span class="k">def</span> <span class="nf">_get_canonical_dag_from_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Constructs canonical DAG as links_coeffs dictionary from graph.</span>

<span class="sd">        For every &lt;-&gt;, +-&gt;, and &lt;-+ link, an additional latent variable is</span>
<span class="sd">        added as a common parent of both endpoints. This corresponds to a</span>
<span class="sd">        canonical DAG (Richardson and Spirtes, 2002).</span>

<span class="sd">        Can be used to evaluate d-separation.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">N</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">tau_maxplusone</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">shape</span>
        <span class="n">tau_max</span> <span class="o">=</span> <span class="n">tau_maxplusone</span> <span class="o">-</span> <span class="mi">1</span>

        <span class="n">links</span> <span class="o">=</span> <span class="p">{</span><span class="n">j</span><span class="p">:</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">)}</span>

        <span class="c1"># Add further latent variables to accommodate &lt;-&gt; links</span>
        <span class="n">latent_index</span> <span class="o">=</span> <span class="n">N</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="p">)):</span>

            <span class="n">edge_type</span> <span class="o">=</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">tau</span><span class="p">]</span>

            <span class="c1"># Consider contemporaneous links only once</span>
            <span class="k">if</span> <span class="n">tau</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">j</span> <span class="o">&gt;</span> <span class="n">i</span><span class="p">:</span>
                <span class="k">continue</span>

            <span class="k">if</span> <span class="n">edge_type</span> <span class="o">==</span> <span class="s2">&quot;--&gt;&quot;</span><span class="p">:</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
            <span class="k">elif</span> <span class="n">edge_type</span> <span class="o">==</span> <span class="s2">&quot;&lt;--&quot;</span><span class="p">:</span>
                <span class="n">links</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
            <span class="k">elif</span> <span class="n">edge_type</span> <span class="o">==</span> <span class="s2">&quot;&lt;-&gt;&quot;</span><span class="p">:</span>
                <span class="n">links</span><span class="p">[</span><span class="n">latent_index</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="n">links</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">latent_index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">latent_index</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
                <span class="n">latent_index</span> <span class="o">+=</span> <span class="mi">1</span>
            <span class="c1"># elif edge_type == &quot;---&quot;:</span>
            <span class="c1">#     links[latent_index] = []</span>
            <span class="c1">#     selection_vars.append(latent_index)</span>
            <span class="c1">#     links[latent_index].append((i, -tau))</span>
            <span class="c1">#     links[latent_index].append((j, 0))</span>
            <span class="c1">#     latent_index += 1</span>
            <span class="k">elif</span> <span class="n">edge_type</span> <span class="o">==</span> <span class="s2">&quot;+-&gt;&quot;</span><span class="p">:</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
                <span class="n">links</span><span class="p">[</span><span class="n">latent_index</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="n">links</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">latent_index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">latent_index</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
                <span class="n">latent_index</span> <span class="o">+=</span> <span class="mi">1</span>
            <span class="k">elif</span> <span class="n">edge_type</span> <span class="o">==</span> <span class="s2">&quot;&lt;-+&quot;</span><span class="p">:</span>
                <span class="n">links</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
                <span class="n">links</span><span class="p">[</span><span class="n">latent_index</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="n">links</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">latent_index</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">latent_index</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">))</span>
                <span class="n">latent_index</span> <span class="o">+=</span> <span class="mi">1</span>

        <span class="k">return</span> <span class="n">links</span>


    <span class="k">def</span> <span class="nf">_get_maximum_possible_lag</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">XYZ</span><span class="p">,</span> <span class="n">graph</span><span class="p">):</span>
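<span class="w">        </span><span class="sd">&quot;&quot;&quot;Return the maximum time lag relevant for d-separation in a stationary graph.</span>

<span class="sd">        Starting from each node in XYZ, parents are traversed recursively along</span>
<span class="sd">        the canonical DAG of graph; a path is only extended by links whose</span>
<span class="sd">        time-shifted copies do not already occur on it. The returned value is</span>
<span class="sd">        the largest absolute lag recorded during this traversal.</span>

<span class="sd">        TO BE REVISED!</span>

<span class="sd">        &quot;&quot;&quot;</span>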

        <span class="k">def</span> <span class="nf">_repeating</span><span class="p">(</span><span class="n">link</span><span class="p">,</span> <span class="n">seen_path</span><span class="p">):</span>
<span class="w">            </span><span class="sd">&quot;&quot;&quot;Return True if link, a pair of nodes ((i, taui), (j, tauj)), or a</span>
<span class="sd">            time-shifted copy of it already occurs as consecutive nodes in seen_path.&quot;&quot;&quot;</span>
            <span class="n">i</span><span class="p">,</span> <span class="n">taui</span> <span class="o">=</span> <span class="n">link</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tauj</span> <span class="o">=</span> <span class="n">link</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>

            <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">seen_link</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">seen_path</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
                <span class="n">seen_i</span><span class="p">,</span> <span class="n">seen_taui</span> <span class="o">=</span> <span class="n">seen_link</span>
                <span class="n">seen_j</span><span class="p">,</span> <span class="n">seen_tauj</span> <span class="o">=</span> <span class="n">seen_path</span><span class="p">[</span><span class="n">index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>

                <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">==</span> <span class="n">seen_i</span> <span class="ow">and</span> <span class="n">j</span> <span class="o">==</span> <span class="n">seen_j</span>
                    <span class="ow">and</span> <span class="nb">abs</span><span class="p">(</span><span class="n">tauj</span><span class="o">-</span><span class="n">taui</span><span class="p">)</span> <span class="o">==</span> <span class="nb">abs</span><span class="p">(</span><span class="n">seen_tauj</span><span class="o">-</span><span class="n">seen_taui</span><span class="p">)):</span>
                    <span class="k">return</span> <span class="kc">True</span>

            <span class="k">return</span> <span class="kc">False</span>

        <span class="c1"># TODO: does this work with PAGs?</span>
        <span class="c1"># if self.possible:</span>
        <span class="c1">#     patterns=[&#39;&lt;*-&#39;, &#39;&lt;*o&#39;, &#39;o*o&#39;] </span>
        <span class="c1"># else:</span>
        <span class="c1">#     patterns=[&#39;&lt;*-&#39;] </span>

        <span class="n">canonical_dag_links</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_canonical_dag_from_graph</span><span class="p">(</span><span class="n">graph</span><span class="p">)</span>

        <span class="n">max_lag</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">XYZ</span><span class="p">:</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">node</span>   <span class="c1"># tau &lt;= 0</span>
            <span class="n">max_lag</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_lag</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">tau</span><span class="p">))</span>

            <span class="n">causal_path</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="n">queue</span> <span class="o">=</span> <span class="p">[(</span><span class="n">node</span><span class="p">,</span> <span class="n">causal_path</span><span class="p">)]</span>

            <span class="k">while</span> <span class="n">queue</span><span class="p">:</span>
                <span class="n">varlag</span><span class="p">,</span> <span class="n">causal_path</span> <span class="o">=</span> <span class="n">queue</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span>
                <span class="n">causal_path</span> <span class="o">=</span> <span class="p">[</span><span class="n">varlag</span><span class="p">]</span> <span class="o">+</span> <span class="n">causal_path</span>

                <span class="n">var</span><span class="p">,</span> <span class="n">lag</span> <span class="o">=</span> <span class="n">varlag</span>
                <span class="k">for</span> <span class="n">partmp</span> <span class="ow">in</span> <span class="n">canonical_dag_links</span><span class="p">[</span><span class="n">var</span><span class="p">]:</span>
                    <span class="n">i</span><span class="p">,</span> <span class="n">tautmp</span> <span class="o">=</span> <span class="n">partmp</span>
                    <span class="c1"># Shift the lag, since canonical_dag_links is given relative to t=0</span>
                    <span class="n">tau</span> <span class="o">=</span> <span class="n">tautmp</span> <span class="o">+</span> <span class="n">lag</span>
                    <span class="n">par</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">tau</span><span class="p">)</span>

                    <span class="k">if</span> <span class="p">(</span><span class="n">par</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">causal_path</span><span class="p">):</span>
                        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">causal_path</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
                            <span class="n">queue</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">par</span><span class="p">,</span> <span class="n">causal_path</span><span class="p">))</span>
                            <span class="k">continue</span>

                        <span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">causal_path</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">_repeating</span><span class="p">((</span><span class="n">par</span><span class="p">,</span> <span class="n">varlag</span><span class="p">),</span> <span class="n">causal_path</span><span class="p">):</span>
                            <span class="n">max_lag</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_lag</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">tau</span><span class="p">))</span>
                            <span class="n">queue</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">par</span><span class="p">,</span> <span class="n">causal_path</span><span class="p">))</span>

        <span class="k">return</span> <span class="n">max_lag</span>

    <span class="k">def</span> <span class="nf">_get_latent_projection_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stationary</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
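<span class="w">        </span><span class="sd">&quot;&quot;&quot;Apply the latent projection operation (Pearl 2009) to a DAG/ADMG.</span>

<span class="sd">        Assumes a normal or stationary graph with potentially unobserved nodes.</span>
<span class="sd">        Particular time steps of a variable may also be unobserved; by</span>
<span class="sd">        stationarity, this pattern of unobserved nodes is repeated back to</span>
<span class="sd">        -infinity.</span>

<span class="sd">        The latents are the nodes before t-tau_max and the hidden confounders</span>
<span class="sd">        represented by &lt;-&gt; links. The projection is defined as follows:</span>
<span class="sd">        (i)  auxADMG contains (i, -taui) --&gt; (j, -tauj) iff there is a directed</span>
<span class="sd">             path (i, -taui) --&gt; ... --&gt; (j, -tauj) on which every</span>
<span class="sd">             non-endpoint vertex is hidden (= not in observed_vars);</span>
<span class="sd">             here iff (i, -|taui-tauj|) --&gt; j in graph.</span>
<span class="sd">        (ii) auxADMG contains (i, -taui) &lt;-&gt; (j, -tauj) iff there is a path of</span>
<span class="sd">             the form (i, -taui) &lt;-- ... --&gt; (j, -tauj) on which every</span>
<span class="sd">             non-endpoint vertex is a non-collider AND in L (= not in</span>
<span class="sd">             observed_vars); here iff (i, -|taui-tauj|) &lt;-&gt; j OR there is a</span>
<span class="sd">             path (i, -taui) &lt;-- nodes before t-tau_max --&gt; (j, -tauj).</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        stationary : bool, optional (default: False)</span>
<span class="sd">            Whether the graph is stationary. If True, nodes before t-tau_max</span>
<span class="sd">            are additionally treated as hidden when checking condition (ii).</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        aux_graph : array of shape (N, N, tau_max + 1, tau_max + 1)</span>
<span class="sd">            Edge marks of the projected graph: aux_graph[i, j, taui, tauj] is</span>
<span class="sd">            the link between (i, -taui) and (j, -tauj) (&#39;--&gt;&#39;, &#39;&lt;--&#39;,</span>
<span class="sd">            &#39;&lt;-&gt;&#39;, &#39;+-&gt;&#39; or &#39;&lt;-+&#39;), or an empty string if there is no link.</span>
<span class="sd">        &quot;&quot;&quot;</span>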
        
        <span class="c1"># graph = self.graph</span>

        <span class="c1"># if self.hidden_variables is None:</span>
        <span class="c1">#     hidden_variables_here = []</span>
        <span class="c1"># else:</span>
        <span class="n">hidden_variables_here</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_variables</span>

        <span class="n">aux_graph</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;&lt;U3&#39;</span><span class="p">)</span>
        <span class="n">aux_graph</span><span class="p">[:]</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
        <span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">)</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">),</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">)):</span>
            <span class="k">for</span> <span class="n">jt</span><span class="p">,</span> <span class="n">tauj</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)):</span>
                <span class="k">for</span> <span class="n">it</span><span class="p">,</span> <span class="n">taui</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)):</span>
                    <span class="n">tau</span> <span class="o">=</span> <span class="nb">abs</span><span class="p">(</span><span class="n">taui</span> <span class="o">-</span> <span class="n">tauj</span><span class="p">)</span>
                    <span class="k">if</span> <span class="n">tau</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">j</span> <span class="o">==</span> <span class="n">i</span><span class="p">:</span>
                        <span class="k">continue</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">taui</span><span class="p">)</span> <span class="ow">in</span> <span class="n">hidden_variables_here</span> <span class="ow">or</span> <span class="p">(</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tauj</span><span class="p">)</span> <span class="ow">in</span> <span class="n">hidden_variables_here</span><span class="p">:</span>
                        <span class="k">continue</span>
                    <span class="c1"># print(&quot;\n&quot;)</span>
                    <span class="c1"># print((i, -taui), (j, -tauj))</span>

                    <span class="n">cond_i_xy</span> <span class="o">=</span> <span class="p">(</span>
                            <span class="c1"># tau &lt;= graph_taumax </span>
                        <span class="c1"># and (graph[i, j, tau] == &#39;--&gt;&#39; or graph[i, j, tau] == &#39;+-&gt;&#39;) </span>
                        <span class="c1">#     )</span>
                          <span class="c1"># and </span>
                          <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span> <span class="c1">#graph=graph,</span>
                                                <span class="n">start</span><span class="o">=</span><span class="p">[(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">taui</span><span class="p">)],</span>
                                                 <span class="n">end</span><span class="o">=</span><span class="p">[(</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tauj</span><span class="p">)],</span>
                                                 <span class="n">conditions</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                                                 <span class="n">starts_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;+*&gt;&#39;</span><span class="p">],</span>
                                                 <span class="n">ends_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;+*&gt;&#39;</span><span class="p">],</span>
                                                 <span class="n">path_type</span><span class="o">=</span><span class="s1">&#39;causal&#39;</span><span class="p">,</span>
                                                 <span class="n">hidden_by_taumax</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                                                 <span class="n">hidden_variables</span><span class="o">=</span><span class="n">hidden_variables_here</span><span class="p">,</span>
                                                 <span class="n">stationary_graph</span><span class="o">=</span><span class="n">stationary</span><span class="p">,</span>
                                                 <span class="p">))</span>
                    <span class="n">cond_i_yx</span> <span class="o">=</span> <span class="p">(</span>
                        <span class="c1"># tau &lt;= graph_taumax </span>
                        <span class="c1"># and (graph[i, j, tau] == &#39;&lt;--&#39; or graph[i, j, tau] == &#39;&lt;-+&#39;) </span>
                        <span class="c1">#     )</span>
                        <span class="c1"># and </span>
                        <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span> <span class="c1">#graph=graph,</span>
                                              <span class="n">start</span><span class="o">=</span><span class="p">[(</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tauj</span><span class="p">)],</span>
                                               <span class="n">end</span><span class="o">=</span><span class="p">[(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">taui</span><span class="p">)],</span>
                                               <span class="n">conditions</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                                               <span class="n">starts_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;+*&gt;&#39;</span><span class="p">],</span>
                                               <span class="n">ends_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;+*&gt;&#39;</span><span class="p">],</span>
                                               <span class="n">path_type</span><span class="o">=</span><span class="s1">&#39;causal&#39;</span><span class="p">,</span>
                                               <span class="n">hidden_by_taumax</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                                               <span class="n">hidden_variables</span><span class="o">=</span><span class="n">hidden_variables_here</span><span class="p">,</span>
                                               <span class="n">stationary_graph</span><span class="o">=</span><span class="n">stationary</span><span class="p">,</span>
                                               <span class="p">))</span>
                    <span class="n">hidden_by_taumax_here</span> <span class="o">=</span> <span class="nb">bool</span><span class="p">(</span><span class="n">stationary</span><span class="p">)</span>

                    <span class="n">cond_ii</span> <span class="o">=</span> <span class="p">(</span>
                        <span class="c1"># tau &lt;= graph_taumax </span>
                                <span class="c1"># and </span>
                                <span class="p">(</span>
                                <span class="c1">#     graph[i, j, tau] == &#39;&lt;-&gt;&#39; </span>
                                <span class="c1"># or graph[i, j, tau] == &#39;+-&gt;&#39; or graph[i, j, tau] == &#39;&lt;-+&#39;)) </span>
                                    <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span> <span class="c1">#graph=graph,</span>
                                                <span class="n">start</span><span class="o">=</span><span class="p">[(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">taui</span><span class="p">)],</span>
                                                 <span class="n">end</span><span class="o">=</span><span class="p">[(</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tauj</span><span class="p">)],</span>
                                                 <span class="n">conditions</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                                                 <span class="n">starts_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;&lt;**&#39;</span><span class="p">,</span> <span class="s1">&#39;+**&#39;</span><span class="p">],</span>
                                                 <span class="n">ends_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;**&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;**+&#39;</span><span class="p">],</span>
                                                 <span class="n">path_type</span><span class="o">=</span><span class="s1">&#39;any&#39;</span><span class="p">,</span>
                                                 <span class="n">hidden_by_taumax</span><span class="o">=</span><span class="n">hidden_by_taumax_here</span><span class="p">,</span>
                                                 <span class="n">hidden_variables</span><span class="o">=</span><span class="n">hidden_variables_here</span><span class="p">,</span>
                                                 <span class="n">stationary_graph</span><span class="o">=</span><span class="n">stationary</span><span class="p">,</span>
                                                 <span class="p">)))</span>

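                    <span class="c1"># Combine the conditions into edge marks: (i) alone yields a directed</span>
                    <span class="c1"># link, (ii) alone a bidirected link &lt;-&gt;, and (i) together with (ii)</span>
                    <span class="c1"># a +-&gt; / &lt;-+ link. Directed paths in both directions imply a cycle.</span>
                    <span class="k">if</span> <span class="n">cond_i_xy</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">cond_i_yx</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">cond_ii</span><span class="p">:</span>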
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;--&gt;&quot;</span>  <span class="c1">#graph[i, j, tau]</span>
                        <span class="c1"># if tau == 0:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">tauj</span><span class="p">,</span> <span class="n">taui</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;&lt;--&quot;</span>  <span class="c1"># graph[j, i, tau]</span>
                    <span class="k">elif</span> <span class="ow">not</span> <span class="n">cond_i_xy</span> <span class="ow">and</span> <span class="n">cond_i_yx</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">cond_ii</span><span class="p">:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;&lt;--&quot;</span>  <span class="c1">#graph[i, j, tau]</span>
                        <span class="c1"># if tau == 0:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">tauj</span><span class="p">,</span> <span class="n">taui</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;--&gt;&quot;</span>  <span class="c1"># graph[j, i, tau]</span>
                    <span class="k">elif</span> <span class="ow">not</span> <span class="n">cond_i_xy</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">cond_i_yx</span> <span class="ow">and</span> <span class="n">cond_ii</span><span class="p">:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;&lt;-&gt;&#39;</span>
                        <span class="c1"># if tau == 0:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">tauj</span><span class="p">,</span> <span class="n">taui</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;&lt;-&gt;&#39;</span>
                    <span class="k">elif</span> <span class="n">cond_i_xy</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">cond_i_yx</span> <span class="ow">and</span> <span class="n">cond_ii</span><span class="p">:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;+-&gt;&#39;</span>
                        <span class="c1"># if tau == 0:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">tauj</span><span class="p">,</span> <span class="n">taui</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;&lt;-+&#39;</span>                        
                    <span class="k">elif</span> <span class="ow">not</span> <span class="n">cond_i_xy</span> <span class="ow">and</span> <span class="n">cond_i_yx</span> <span class="ow">and</span> <span class="n">cond_ii</span><span class="p">:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">taui</span><span class="p">,</span> <span class="n">tauj</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;&lt;-+&#39;</span>
                        <span class="c1"># if tau == 0:</span>
                        <span class="n">aux_graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">tauj</span><span class="p">,</span> <span class="n">taui</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;+-&gt;&#39;</span> 
                    <span class="k">elif</span> <span class="n">cond_i_xy</span> <span class="ow">and</span> <span class="n">cond_i_yx</span><span class="p">:</span>
                        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Cycle between </span><span class="si">%s</span><span class="s2"> and </span><span class="si">%s</span><span class="s2">!&quot;</span> <span class="o">%</span><span class="p">(</span><span class="nb">str</span><span class="p">((</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">taui</span><span class="p">)),</span> <span class="nb">str</span><span class="p">((</span><span class="n">j</span><span class="p">,</span> <span class="o">-</span><span class="n">tauj</span><span class="p">))))</span>
                    <span class="c1"># print(aux_graph[i, j, taui, tauj])</span>

                    <span class="c1"># print((i, -taui), (j, -tauj), cond_i_xy, cond_i_yx, cond_ii, aux_graph[i, j, taui, tauj], aux_graph[j, i, tauj, taui])</span>

        <span class="k">return</span> <span class="n">aux_graph</span>

    <span class="k">def</span> <span class="nf">_check_path</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> 
        <span class="c1"># graph, </span>
        <span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span>
        <span class="n">conditions</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> 
        <span class="n">starts_with</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">ends_with</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">path_type</span><span class="o">=</span><span class="s1">&#39;any&#39;</span><span class="p">,</span>
        <span class="c1"># causal_children=None,</span>
        <span class="n">stationary_graph</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="n">hidden_by_taumax</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="n">hidden_variables</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check whether an open (active) path between start and end exists given conditions.</span>

<span class="sd">        Optionally restricts the edge marks at the start and end of the path</span>
<span class="sd">        (starts_with, ends_with) and the type of path (path_type), e.g. to</span>
<span class="sd">        causal or non-causal paths.</span>

<span class="sd">        hidden_by_taumax and hidden_variables specify which nodes are treated as</span>
<span class="sd">        hidden; they are relevant for the latent projection operation.</span>
<span class="sd">        &quot;&quot;&quot;</span>


        <span class="k">if</span> <span class="n">conditions</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">conditions</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([])</span>
        <span class="c1"># if conditioned_variables is None:</span>
        <span class="c1">#     S = []</span>

        <span class="n">start</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">start</span><span class="p">)</span>
        <span class="n">end</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">end</span><span class="p">)</span>
        <span class="n">conditions</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">conditions</span><span class="p">)</span>
        
        <span class="c1"># Get maximal possible time lag of a connecting path</span>
        <span class="c1"># See Thm. XXXX - TO BE REVISED!</span>
        <span class="n">XYZ</span> <span class="o">=</span> <span class="n">start</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">end</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">conditions</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">stationary_graph</span><span class="p">:</span>
            <span class="n">max_lag</span> <span class="o">=</span> <span class="mi">10</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span>  <span class="c1"># TO BE REVISED! self._get_maximum_possible_lag(XYZ, self.graph)</span>
            <span class="n">causal_children</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_mediators_stationary_graph</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span> <span class="n">max_lag</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">max_lag</span> <span class="o">=</span> <span class="kc">None</span>
            <span class="n">causal_children</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_mediators</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">end</span><span class="p">))</span>
       

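        <span class="c1"># Treat all nodes at lags beyond tau_max (up to max_lag) as hidden.</span>
        <span class="c1"># This requires stationary_graph=True: otherwise max_lag is None and</span>
        <span class="c1"># the range below would fail.</span>
        <span class="k">if</span> <span class="n">hidden_by_taumax</span><span class="p">:</span>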
            <span class="k">if</span> <span class="n">hidden_variables</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="n">hidden_variables</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([])</span>
            <span class="n">hidden_variables</span> <span class="o">=</span> <span class="n">hidden_variables</span><span class="o">.</span><span class="n">union</span><span class="p">([(</span><span class="n">k</span><span class="p">,</span> <span class="o">-</span><span class="n">tauk</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">)</span> 
                                            <span class="k">for</span> <span class="n">tauk</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_lag</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)])</span>


        <span class="k">if</span> <span class="n">starts_with</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">starts_with</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;***&#39;</span><span class="p">]</span>
        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">starts_with</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
            <span class="n">starts_with</span> <span class="o">=</span> <span class="p">[</span><span class="n">starts_with</span><span class="p">]</span>

        <span class="k">if</span> <span class="n">ends_with</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">ends_with</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;***&#39;</span><span class="p">]</span>
        <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">ends_with</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
            <span class="n">ends_with</span> <span class="o">=</span> <span class="p">[</span><span class="n">ends_with</span><span class="p">]</span>
        <span class="c1">#</span>
        <span class="c1"># Graph search for a connecting path. Each search state is a triple</span>
        <span class="c1"># (varlag_i, link_ik, varlag_k): the last link traversed and its two</span>
        <span class="c1"># endpoints. States are popped from a set, so the traversal order is</span>
        <span class="c1"># arbitrary rather than strictly breadth-first.</span>
        <span class="c1">#</span>
        <span class="n">start_from</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">start</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">stationary_graph</span><span class="p">:</span>
                <span class="n">link_neighbors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_adjacents_stationary_graph</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">graph</span><span class="p">,</span> <span class="n">node</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="n">starts_with</span><span class="p">,</span> 
                                        <span class="n">max_lag</span><span class="o">=</span><span class="n">max_lag</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">link_neighbors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_adj</span><span class="p">(</span><span class="n">node</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="n">starts_with</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">start</span><span class="p">),</span> <span class="n">return_link</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
            
            <span class="k">for</span> <span class="n">link_neighbor</span> <span class="ow">in</span> <span class="n">link_neighbors</span><span class="p">:</span>
                <span class="n">link</span><span class="p">,</span> <span class="n">neighbor</span> <span class="o">=</span> <span class="n">link_neighbor</span>


                <span class="k">if</span> <span class="p">(</span><span class="n">hidden_variables</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">neighbor</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">end</span>
                                    <span class="ow">and</span> <span class="n">neighbor</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">hidden_variables</span><span class="p">):</span>
                    <span class="k">continue</span>

                <span class="k">if</span> <span class="n">path_type</span> <span class="o">==</span> <span class="s1">&#39;non_causal&#39;</span><span class="p">:</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">neighbor</span> <span class="ow">in</span> <span class="n">causal_children</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="n">link</span><span class="p">)</span> 
                        <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="s1">&#39;+*&gt;&#39;</span><span class="p">,</span> <span class="n">link</span><span class="p">)):</span>
                        <span class="k">continue</span>
                <span class="k">elif</span> <span class="n">path_type</span> <span class="o">==</span> <span class="s1">&#39;causal&#39;</span><span class="p">:</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">neighbor</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">causal_children</span><span class="p">):</span> <span class="c1"># or self._match_link(&#39;&lt;**&#39;, link)):</span>
                        <span class="k">continue</span>                    
                <span class="n">start_from</span><span class="o">.</span><span class="n">add</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">link</span><span class="p">,</span> <span class="n">neighbor</span><span class="p">))</span>


        <span class="n">visited</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
        <span class="k">for</span> <span class="p">(</span><span class="n">varlag_i</span><span class="p">,</span> <span class="n">link_ik</span><span class="p">,</span> <span class="n">varlag_k</span><span class="p">)</span> <span class="ow">in</span> <span class="n">start_from</span><span class="p">:</span>
            <span class="n">visited</span><span class="o">.</span><span class="n">add</span><span class="p">((</span><span class="n">link_ik</span><span class="p">,</span> <span class="n">varlag_k</span><span class="p">))</span>

        <span class="c1"># Traverse motifs i *-* k *-* j, extending a path from k to j only if the motif is open</span>
        <span class="k">while</span> <span class="n">start_from</span><span class="p">:</span>

            <span class="n">removables</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="k">for</span> <span class="p">(</span><span class="n">varlag_i</span><span class="p">,</span> <span class="n">link_ik</span><span class="p">,</span> <span class="n">varlag_k</span><span class="p">)</span> <span class="ow">in</span> <span class="n">start_from</span><span class="p">:</span>

                <span class="k">if</span> <span class="n">varlag_k</span> <span class="ow">in</span> <span class="n">end</span><span class="p">:</span>
                    <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="n">patt</span><span class="p">,</span> <span class="n">link_ik</span><span class="p">)</span> <span class="k">for</span> <span class="n">patt</span> <span class="ow">in</span> <span class="n">ends_with</span><span class="p">]):</span>
                        <span class="k">return</span> <span class="kc">True</span>
                    <span class="k">else</span><span class="p">:</span>
                        <span class="n">removables</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">varlag_i</span><span class="p">,</span> <span class="n">link_ik</span><span class="p">,</span> <span class="n">varlag_k</span><span class="p">))</span>

            <span class="k">for</span> <span class="n">removable</span> <span class="ow">in</span> <span class="n">removables</span><span class="p">:</span>
                <span class="n">start_from</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">removable</span><span class="p">)</span>
            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">start_from</span><span class="p">)</span><span class="o">==</span><span class="mi">0</span><span class="p">:</span>
                <span class="k">return</span> <span class="kc">False</span>

            <span class="c1"># Pop an arbitrary search state and expand from its end node varlag_k</span>
            <span class="n">varlag_i</span><span class="p">,</span> <span class="n">link_ik</span><span class="p">,</span> <span class="n">varlag_k</span> <span class="o">=</span> <span class="n">start_from</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span>


            <span class="k">if</span> <span class="n">stationary_graph</span><span class="p">:</span>
                <span class="n">link_neighbors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_adjacents_stationary_graph</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">graph</span><span class="p">,</span> <span class="n">node</span><span class="o">=</span><span class="n">varlag_k</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="s1">&#39;***&#39;</span><span class="p">,</span> 
                                        <span class="n">max_lag</span><span class="o">=</span><span class="n">max_lag</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">start</span><span class="p">))</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">link_neighbors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_find_adj</span><span class="p">(</span><span class="n">node</span><span class="o">=</span><span class="n">varlag_k</span><span class="p">,</span> <span class="n">patterns</span><span class="o">=</span><span class="s1">&#39;***&#39;</span><span class="p">,</span> <span class="n">exclude</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">start</span><span class="p">),</span> <span class="n">return_link</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
            
            <span class="k">for</span> <span class="n">link_neighbor</span> <span class="ow">in</span> <span class="n">link_neighbors</span><span class="p">:</span>
                <span class="n">link_kj</span><span class="p">,</span> <span class="n">varlag_j</span> <span class="o">=</span> <span class="n">link_neighbor</span>
                <span class="c1"># Skip (link, node) states that have already been reached</span>
                <span class="k">if</span> <span class="p">(</span><span class="n">link_kj</span><span class="p">,</span> <span class="n">varlag_j</span><span class="p">)</span> <span class="ow">in</span> <span class="n">visited</span><span class="p">:</span>
                    <span class="k">continue</span>

                <span class="k">if</span> <span class="n">path_type</span> <span class="o">==</span> <span class="s1">&#39;causal&#39;</span><span class="p">:</span>
                    <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="s1">&#39;-*&gt;&#39;</span><span class="p">,</span> <span class="n">link_kj</span><span class="p">)</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_match_link</span><span class="p">(</span><span class="s1">&#39;+*&gt;&#39;</span><span class="p">,</span> <span class="n">link_kj</span><span class="p">)):</span>
                        <span class="k">continue</span> 

                <span class="c1"># If motif i *-* k *-* j is open, add (link_kj, varlag_j) to visited</span>
                <span class="c1"># and (varlag_k, link_kj, varlag_j) to start_from</span>
                <span class="n">left_mark</span> <span class="o">=</span> <span class="n">link_ik</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
                <span class="n">right_mark</span> <span class="o">=</span> <span class="n">link_kj</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
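                <span class="c1"># Summary of the checks below, read off this code (a reading aid, not a</span>
                <span class="c1"># restatement of the underlying theory): left_mark is the mark at k on</span>
                <span class="c1"># link i-k and right_mark is the mark at k on link k-j.</span>
                <span class="c1">#   - definite_status: skip motifs that are not of definite status, and</span>
                <span class="c1">#     block i *-o k o-* j when k is in conditions</span>
                <span class="c1">#   - k in conditions and a tail &#39;-&#39; at k             -&gt; blocked</span>
                <span class="c1">#   - k not in conditions and i *-&gt; k &lt;-* j (collider) -&gt; blocked</span>
                <span class="c1">#   - hidden_variables given, j in neither end nor hidden_variables -&gt; skipped</span>
                <span class="c1">#   - otherwise the motif is open and j is queued for expansion</span>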

                <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">definite_status</span><span class="p">:</span>
                    <span class="c1"># Exclude paths that are not of definite status, i.e., paths on which</span>
                    <span class="c1"># any of the following motifs occurs:</span>
                    <span class="c1"># i *-&gt; k o-* j</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">left_mark</span> <span class="o">==</span> <span class="s1">&#39;&gt;&#39;</span> <span class="ow">and</span> <span class="n">right_mark</span> <span class="o">==</span> <span class="s1">&#39;o&#39;</span><span class="p">):</span>
                        <span class="k">continue</span>
                    <span class="c1"># i *-o k &lt;-* j</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">left_mark</span> <span class="o">==</span> <span class="s1">&#39;o&#39;</span> <span class="ow">and</span> <span class="n">right_mark</span> <span class="o">==</span> <span class="s1">&#39;&lt;&#39;</span><span class="p">):</span>
                        <span class="k">continue</span>
                    <span class="c1"># i *-o k o-* j and i and j are adjacent</span>
                    <span class="k">if</span> <span class="p">(</span><span class="n">left_mark</span> <span class="o">==</span> <span class="s1">&#39;o&#39;</span> <span class="ow">and</span> <span class="n">right_mark</span> <span class="o">==</span> <span class="s1">&#39;o&#39;</span>
                        <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_match</span><span class="p">(</span><span class="n">varlag_i</span><span class="p">,</span> <span class="n">varlag_j</span><span class="p">,</span> <span class="s2">&quot;***&quot;</span><span class="p">)):</span>
                        <span class="k">continue</span>

                    <span class="c1"># If k is in conditions and the motif is i *-o k o-* j, the motif is blocked,</span>
                    <span class="c1"># since the check above guarantees that i and j are non-adjacent</span>
                    <span class="k">if</span> <span class="n">varlag_k</span> <span class="ow">in</span> <span class="n">conditions</span> <span class="ow">and</span> <span class="p">(</span><span class="n">left_mark</span> <span class="o">==</span> <span class="s1">&#39;o&#39;</span> <span class="ow">and</span> <span class="n">right_mark</span> <span class="o">==</span> <span class="s1">&#39;o&#39;</span><span class="p">):</span>
                        <span class="k">continue</span>  <span class="c1"># [(&#39;&gt;&#39;, &#39;&lt;&#39;), (&#39;&gt;&#39;, &#39;+&#39;), (&#39;+&#39;, &#39;&lt;&#39;), (&#39;+&#39;, &#39;+&#39;)]</span>

                <span class="c1"># If k is in conditions and left or right mark is tail &#39;-&#39;, then motif is blocked</span>
                <span class="k">if</span> <span class="n">varlag_k</span> <span class="ow">in</span> <span class="n">conditions</span> <span class="ow">and</span> <span class="p">(</span><span class="n">left_mark</span> <span class="o">==</span> <span class="s1">&#39;-&#39;</span> <span class="ow">or</span> <span class="n">right_mark</span> <span class="o">==</span> <span class="s1">&#39;-&#39;</span><span class="p">):</span>
                    <span class="k">continue</span>  <span class="c1"># [(&#39;&gt;&#39;, &#39;&lt;&#39;), (&#39;&gt;&#39;, &#39;+&#39;), (&#39;+&#39;, &#39;&lt;&#39;), (&#39;+&#39;, &#39;+&#39;)]</span>

                <span class="c1"># If k is not in conditions and both marks at k are arrowheads (collider i *-&gt; k &lt;-* j), then motif is blocked</span>
                <span class="k">if</span> <span class="n">varlag_k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">conditions</span> <span class="ow">and</span> <span class="p">(</span><span class="n">left_mark</span> <span class="o">==</span> <span class="s1">&#39;&gt;&#39;</span> <span class="ow">and</span> <span class="n">right_mark</span> <span class="o">==</span> <span class="s1">&#39;&lt;&#39;</span><span class="p">):</span>
                    <span class="k">continue</span>  <span class="c1"># [(&#39;&gt;&#39;, &#39;&lt;&#39;), (&#39;&gt;&#39;, &#39;+&#39;), (&#39;+&#39;, &#39;&lt;&#39;), (&#39;+&#39;, &#39;+&#39;)]</span>


                <span class="k">if</span> <span class="p">(</span><span class="n">hidden_variables</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">varlag_j</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">end</span>
                                    <span class="ow">and</span> <span class="n">varlag_j</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">hidden_variables</span><span class="p">):</span>
                    <span class="k">continue</span>

                <span class="c1"># Motif is open: mark (link_kj, varlag_j) as visited and queue it for expansion</span>
                <span class="n">visited</span><span class="o">.</span><span class="n">add</span><span class="p">((</span><span class="n">link_kj</span><span class="p">,</span> <span class="n">varlag_j</span><span class="p">))</span>
                <span class="n">start_from</span><span class="o">.</span><span class="n">add</span><span class="p">((</span><span class="n">varlag_k</span><span class="p">,</span> <span class="n">link_kj</span><span class="p">,</span> <span class="n">varlag_j</span><span class="p">))</span>


        <span class="c1"># No open path matching the requested patterns and path type was found</span>
        <span class="k">return</span> <span class="kc">False</span>

<div class="viewcode-block" id="CausalEffects.get_optimal_set"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.get_optimal_set">[docs]</a>    <span class="k">def</span> <span class="nf">get_optimal_set</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> 
        <span class="n">alternative_conditions</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">minimize</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="n">return_separate_sets</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns the optimal adjustment set.</span>
<span class="sd">        </span>
<span class="sd">        See Runge NeurIPS 2021.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        alternative_conditions : set of tuples</span>
<span class="sd">            Used only internally by the optimality theorem, where O-sets for</span>
<span class="sd">            alternative conditions are tested. If None, self.S is used.</span>
<span class="sd">        minimize : {False, True, &#39;colliders_only&#39;}</span>
<span class="sd">            Whether to minimize the optimal set. If True, minimize it such that</span>
<span class="sd">            no subset can be removed without making it invalid. If</span>
<span class="sd">            &#39;colliders_only&#39;, only the colliders are minimized.</span>
<span class="sd">        return_separate_sets : bool</span>
<span class="sd">            Whether to return the tuple (parents, colliders, collider_parents, S)</span>
<span class="sd">            instead of a single list.</span>
<span class="sd">        </span>
<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        Oset_S : False or list or tuple of lists</span>
<span class="sd">            Returns optimal adjustment set if a valid set exists, otherwise False.</span>
<span class="sd">        &quot;&quot;&quot;</span>


        <span class="c1"># Needed for the optimality theorem, where O-sets for alternative S are tested</span>
        <span class="k">if</span> <span class="n">alternative_conditions</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">S</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
            <span class="n">vancs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">vancs</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">S</span> <span class="o">=</span> <span class="n">alternative_conditions</span>
            <span class="n">newancS</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_ancestors</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>
            <span class="n">vancs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ancX</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ancY</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">newancS</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span>


        <span class="c1"># Sufficient condition for non-identifiability</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">descendants</span><span class="p">))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">False</span>  <span class="c1"># raise ValueError(&quot;Not identifiable: Overlap between X and des(M)&quot;)</span>

        <span class="c1">##</span>
        <span class="c1">## Construct O-set</span>
        <span class="c1">##</span>

        <span class="c1"># Start with the parents of Y and M</span>
        <span class="n">parents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">))</span>

        <span class="c1"># Remove forbidden nodes</span>
        <span class="n">parents</span> <span class="o">=</span> <span class="n">parents</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span>

        <span class="c1"># Construct valid collider path nodes, starting from each node in Y and M</span>
        <span class="n">colliders</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([])</span>
        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">):</span>
            <span class="n">j</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">w</span> 
            <span class="n">this_level</span> <span class="o">=</span> <span class="p">[</span><span class="n">w</span><span class="p">]</span>
            <span class="n">non_suitable_nodes</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">this_level</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">next_level</span> <span class="o">=</span> <span class="p">[]</span>
                <span class="k">for</span> <span class="n">varlag</span> <span class="ow">in</span> <span class="n">this_level</span><span class="p">:</span>
                    <span class="n">suitable_spouses</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_spouses</span><span class="p">(</span><span class="n">varlag</span><span class="p">))</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">non_suitable_nodes</span><span class="p">)</span>
                    <span class="k">for</span> <span class="n">spouse</span> <span class="ow">in</span> <span class="n">suitable_spouses</span><span class="p">:</span>
                        <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">spouse</span>
                        <span class="k">if</span> <span class="n">spouse</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">:</span>
                            <span class="k">return</span> <span class="kc">False</span>

                        <span class="k">if</span> <span class="p">(</span><span class="c1"># Node not already in set</span>
                            <span class="n">spouse</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">colliders</span>  <span class="c1">#.union(parents)</span>
                            <span class="c1"># not forbidden</span>
                            <span class="ow">and</span> <span class="n">spouse</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span> 
                            <span class="c1"># in time bounds</span>
                            <span class="ow">and</span> <span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># or self.ignore_time_bounds)</span>
                            <span class="ow">and</span> <span class="p">(</span><span class="n">spouse</span> <span class="ow">in</span> <span class="n">vancs</span>
                                <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
                                    <span class="n">start</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="p">[</span><span class="n">spouse</span><span class="p">],</span> 
                                                    <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">parents</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">vancs</span><span class="p">))</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">S</span><span class="p">),</span>
                                                    <span class="p">))</span>
                                <span class="p">):</span>
                                <span class="n">colliders</span> <span class="o">=</span> <span class="n">colliders</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">spouse</span><span class="p">]))</span>
                                <span class="n">next_level</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">spouse</span><span class="p">)</span>
                        <span class="k">else</span><span class="p">:</span>
                            <span class="k">if</span> <span class="n">spouse</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">colliders</span><span class="p">:</span>
                                <span class="n">non_suitable_nodes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">spouse</span><span class="p">)</span>


                <span class="n">this_level</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">next_level</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">non_suitable_nodes</span><span class="p">)</span>  

        <span class="c1"># Add parents of colliders; return False if not identifiable</span>
        <span class="n">collider_parents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">(</span><span class="n">colliders</span><span class="p">)</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">collider_parents</span><span class="p">))</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">False</span>

        <span class="n">colliders_and_their_parents</span> <span class="o">=</span> <span class="n">colliders</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">collider_parents</span><span class="p">)</span>

        <span class="c1"># Add valid collider path nodes and their parents</span>
        <span class="n">Oset</span> <span class="o">=</span> <span class="n">parents</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">colliders_and_their_parents</span><span class="p">)</span>


        <span class="k">if</span> <span class="n">minimize</span><span class="p">:</span> 
            <span class="n">removable</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="c1"># First remove all those that have no path from X</span>
            <span class="n">sorted_Oset</span> <span class="o">=</span> <span class="n">Oset</span>
            <span class="k">if</span> <span class="n">minimize</span> <span class="o">==</span> <span class="s1">&#39;colliders_only&#39;</span><span class="p">:</span>
                <span class="n">sorted_Oset</span> <span class="o">=</span> <span class="p">[</span><span class="n">node</span> <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">sorted_Oset</span> <span class="k">if</span> <span class="n">node</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">parents</span><span class="p">]</span>

            <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">sorted_Oset</span><span class="p">:</span>
                <span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
                    <span class="n">start</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="p">[</span><span class="n">node</span><span class="p">],</span> 
                                <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">Oset</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">node</span><span class="p">]))</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">S</span><span class="p">))):</span>
                    <span class="n">removable</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">node</span><span class="p">)</span> 

            <span class="n">Oset</span> <span class="o">=</span> <span class="n">Oset</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">removable</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">minimize</span> <span class="o">==</span> <span class="s1">&#39;colliders_only&#39;</span><span class="p">:</span>
                <span class="n">sorted_Oset</span> <span class="o">=</span> <span class="p">[</span><span class="n">node</span> <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">Oset</span> <span class="k">if</span> <span class="n">node</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">parents</span><span class="p">]</span>

            <span class="n">removable</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="c1"># Next remove all those with no open path to Y given the remaining Oset, S and X</span>
            <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">sorted_Oset</span><span class="p">:</span>
                <span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
                    <span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="n">node</span><span class="p">],</span> <span class="n">end</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">,</span> 
                            <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">Oset</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">node</span><span class="p">]))</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">S</span><span class="p">)</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">),</span>
                            <span class="n">ends_with</span><span class="o">=</span><span class="p">[</span><span class="s1">&#39;**&gt;&#39;</span><span class="p">,</span> <span class="s1">&#39;**+&#39;</span><span class="p">])):</span> 
                    <span class="n">removable</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">node</span><span class="p">)</span> 

            <span class="n">Oset</span> <span class="o">=</span> <span class="n">Oset</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">removable</span><span class="p">)</span>

        <span class="n">Oset_S</span> <span class="o">=</span> <span class="n">Oset</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">S</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">return_separate_sets</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">parents</span><span class="p">,</span> <span class="n">colliders</span><span class="p">,</span> <span class="n">collider_parents</span><span class="p">,</span> <span class="n">S</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">Oset_S</span><span class="p">)</span></div>


    <span class="k">def</span> <span class="nf">_get_collider_paths_optimality</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">source_nodes</span><span class="p">,</span> <span class="n">target_nodes</span><span class="p">,</span>
        <span class="n">condition</span><span class="p">,</span> 
        <span class="n">inside_set</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> 
        <span class="n">start_with_tail_or_head</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> 
        <span class="p">):</span>
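<span class="w">        </span><span class="sd">&quot;&quot;&quot;Checks collider-path condition I or II of the optimality theorem.</span>

<span class="sd">        Iterates over collider paths within the O-set via depth-first search</span>
<span class="sd">        and returns True or False depending on whether the condition holds.</span>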

<span class="sd">        &quot;&quot;&quot;</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">source_nodes</span><span class="p">:</span>
            <span class="c1"># Only used to return *all* collider paths </span>
            <span class="c1"># (needed in optimality theorem)</span>
            
            <span class="n">coll_path</span> <span class="o">=</span> <span class="p">[]</span>

            <span class="n">queue</span> <span class="o">=</span> <span class="p">[(</span><span class="n">w</span><span class="p">,</span> <span class="n">coll_path</span><span class="p">)]</span>

            <span class="n">non_valid_subsets</span> <span class="o">=</span> <span class="p">[]</span>

            <span class="k">while</span> <span class="n">queue</span><span class="p">:</span>

                <span class="n">varlag</span><span class="p">,</span> <span class="n">coll_path</span> <span class="o">=</span> <span class="n">queue</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span>

                <span class="n">coll_path</span> <span class="o">=</span> <span class="n">coll_path</span> <span class="o">+</span> <span class="p">[</span><span class="n">varlag</span><span class="p">]</span>

                <span class="n">suitable_nodes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_spouses</span><span class="p">(</span><span class="n">varlag</span><span class="p">))</span>

                <span class="k">if</span> <span class="n">start_with_tail_or_head</span> <span class="ow">and</span> <span class="n">coll_path</span> <span class="o">==</span> <span class="p">[</span><span class="n">w</span><span class="p">]:</span>
                    <span class="n">children</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_children</span><span class="p">(</span><span class="n">varlag</span><span class="p">))</span>
                    <span class="n">suitable_nodes</span> <span class="o">=</span> <span class="n">suitable_nodes</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">children</span><span class="p">)</span>
 
                <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">suitable_nodes</span><span class="p">:</span>
                    <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">node</span>
                    <span class="k">if</span> <span class="p">((</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># or self.ignore_time_bounds)</span>
                        <span class="ow">and</span> <span class="n">node</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">coll_path</span><span class="p">):</span>

                        <span class="k">if</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;II&#39;</span> <span class="ow">and</span> <span class="n">node</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">target_nodes</span> <span class="ow">and</span> <span class="n">node</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">vancs</span><span class="p">:</span>
                            <span class="k">continue</span>

                        <span class="k">if</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">inside_set</span><span class="p">:</span>
                            <span class="k">if</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;I&#39;</span><span class="p">:</span>
                                <span class="n">non_valid</span> <span class="o">=</span> <span class="kc">False</span>
                                <span class="k">for</span> <span class="n">pathset</span> <span class="ow">in</span> <span class="n">non_valid_subsets</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
                                    <span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">pathset</span><span class="p">)</span><span class="o">.</span><span class="n">issubset</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">coll_path</span> <span class="o">+</span> <span class="p">[</span><span class="n">node</span><span class="p">])):</span>
                                        <span class="n">non_valid</span> <span class="o">=</span> <span class="kc">True</span>
                                        <span class="k">break</span>
                                <span class="k">if</span> <span class="n">non_valid</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                                    <span class="n">queue</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">node</span><span class="p">,</span> <span class="n">coll_path</span><span class="p">))</span> 
                                <span class="k">else</span><span class="p">:</span>
                                    <span class="k">continue</span>
                            <span class="k">elif</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;II&#39;</span><span class="p">:</span>
                                <span class="n">queue</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">node</span><span class="p">,</span> <span class="n">coll_path</span><span class="p">))</span>

                        <span class="k">if</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">target_nodes</span><span class="p">:</span>  
                            <span class="c1"># yield coll_path</span>
                            <span class="c1"># collider_paths[node].append(coll_path) </span>
                            <span class="k">if</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;I&#39;</span><span class="p">:</span>         
                                <span class="c1"># Construct OπiN</span>
                                <span class="n">Sprime</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">coll_path</span><span class="p">)</span>
                                <span class="n">OpiN</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimal_set</span><span class="p">(</span><span class="n">alternative_conditions</span><span class="o">=</span><span class="n">Sprime</span><span class="p">)</span>
                                <span class="k">if</span> <span class="n">OpiN</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                                    <span class="n">queue</span> <span class="o">=</span> <span class="p">[(</span><span class="n">q_node</span><span class="p">,</span> <span class="n">q_path</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">q_node</span><span class="p">,</span> <span class="n">q_path</span><span class="p">)</span> <span class="ow">in</span> <span class="n">queue</span> <span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">coll_path</span><span class="p">)</span><span class="o">.</span><span class="n">issubset</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">q_path</span> <span class="o">+</span> <span class="p">[</span><span class="n">q_node</span><span class="p">]))</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">]</span>
                                    <span class="n">non_valid_subsets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">coll_path</span><span class="p">)</span>
                                <span class="k">else</span><span class="p">:</span>
                                    <span class="k">return</span> <span class="kc">False</span>

                            <span class="k">elif</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;II&#39;</span><span class="p">:</span>
                                <span class="k">return</span> <span class="kc">True</span>
                                <span class="c1"># yield coll_path</span>
 
        <span class="k">if</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;I&#39;</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">True</span>
        <span class="k">elif</span> <span class="n">condition</span> <span class="o">==</span> <span class="s1">&#39;II&#39;</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">False</span>
        <span class="c1"># return collider_paths</span>


<div class="viewcode-block" id="CausalEffects.check_optimality"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.check_optimality">[docs]</a>    <span class="k">def</span> <span class="nf">check_optimality</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Check whether an optimal adjustment set exists according to Thm. 3 in Runge, NeurIPS 2021.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        optimality : bool</span>
<span class="sd">            Returns True if an optimal adjustment set exists, otherwise False.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="c1"># Cond. 0: Exactly one valid adjustment set exists</span>
        <span class="n">cond_0</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_all_valid_adjustment_sets</span><span class="p">(</span><span class="n">check_one_set_exists</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span>

        <span class="c1">#</span>
        <span class="c1"># Cond. I</span>
        <span class="c1">#</span>
        <span class="n">parents</span><span class="p">,</span> <span class="n">colliders</span><span class="p">,</span> <span class="n">collider_parents</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimal_set</span><span class="p">(</span><span class="n">return_separate_sets</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="n">Oset</span> <span class="o">=</span> <span class="n">parents</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">colliders</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">collider_parents</span><span class="p">)</span>
        <span class="n">n_nodes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_spouses</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">colliders</span><span class="p">))</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span> <span class="o">-</span> <span class="n">Oset</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">S</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span> <span class="o">-</span> <span class="n">colliders</span>

        <span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span>
            <span class="c1"># (1) There are no spouses N ∈ sp(YMC) \ (forbOS)</span>
            <span class="n">cond_I</span> <span class="o">=</span> <span class="kc">True</span>
        <span class="k">else</span><span class="p">:</span>
            
            <span class="c1"># (2) For all N ∈ n_nodes and all collider paths πi of N it holds that</span>
            <span class="c1"># OπiN does not block all non-causal paths from X to Y</span>
            <span class="c1"># cond_I = True</span>
            <span class="n">cond_I</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_collider_paths_optimality</span><span class="p">(</span>
                <span class="n">source_nodes</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">n_nodes</span><span class="p">),</span> <span class="n">target_nodes</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)),</span>
                <span class="n">condition</span><span class="o">=</span><span class="s1">&#39;I&#39;</span><span class="p">,</span> 
                <span class="n">inside_set</span><span class="o">=</span><span class="n">Oset</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">),</span> <span class="n">start_with_tail_or_head</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                <span class="p">)</span>

        <span class="c1">#</span>
        <span class="c1"># Cond. II</span>
        <span class="c1">#</span>
        <span class="n">e_nodes</span> <span class="o">=</span> <span class="n">Oset</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="n">parents</span><span class="p">)</span>
        <span class="n">cond_II</span> <span class="o">=</span> <span class="kc">True</span>
        <span class="k">for</span> <span class="n">E</span> <span class="ow">in</span> <span class="n">e_nodes</span><span class="p">:</span>
            <span class="n">Oset_minusE</span> <span class="o">=</span> <span class="n">Oset</span><span class="o">.</span><span class="n">difference</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">E</span><span class="p">]))</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
                <span class="n">start</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">),</span> <span class="n">end</span><span class="o">=</span><span class="p">[</span><span class="n">E</span><span class="p">],</span> 
                                <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">)</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">Oset_minusE</span><span class="p">)):</span>
                   
                <span class="n">cond_II</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_collider_paths_optimality</span><span class="p">(</span>
                    <span class="n">target_nodes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">),</span> 
                    <span class="n">source_nodes</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">E</span><span class="p">])),</span>
                    <span class="n">condition</span><span class="o">=</span><span class="s1">&#39;II&#39;</span><span class="p">,</span> 
                    <span class="n">inside_set</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">Oset</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">)),</span>
                    <span class="n">start_with_tail_or_head</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span>
               
                <span class="k">if</span> <span class="n">cond_II</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                    <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
                        <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Non-optimal due to E = &quot;</span><span class="p">,</span> <span class="n">E</span><span class="p">)</span>
                    <span class="k">break</span>

        <span class="n">optimality</span> <span class="o">=</span> <span class="p">(</span><span class="n">cond_0</span> <span class="ow">or</span> <span class="p">(</span><span class="n">cond_I</span> <span class="ow">and</span> <span class="n">cond_II</span><span class="p">))</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Optimality = </span><span class="si">%s</span><span class="s2"> with cond_0 = </span><span class="si">%s</span><span class="s2">, cond_I = </span><span class="si">%s</span><span class="s2">, cond_II = </span><span class="si">%s</span><span class="s2">&quot;</span>
                    <span class="o">%</span>  <span class="p">(</span><span class="n">optimality</span><span class="p">,</span> <span class="n">cond_0</span><span class="p">,</span> <span class="n">cond_I</span><span class="p">,</span> <span class="n">cond_II</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">optimality</span></div>

    <span class="k">def</span> <span class="nf">_check_validity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Checks whether Z is a valid adjustment set, i.e., whether no</span>
<span class="sd">        non-causal path from X to Y is open given Z.&quot;&quot;&quot;</span>

        <span class="c1"># causal_children = list(self.M.union(self.Y))</span>
        <span class="n">backdoor_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
            <span class="n">start</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">),</span> <span class="n">end</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">),</span> 
                            <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">Z</span><span class="p">),</span> 
                            <span class="c1"># causal_children=causal_children,</span>
                            <span class="n">path_type</span> <span class="o">=</span> <span class="s1">&#39;non_causal&#39;</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">backdoor_path</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">False</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">True</span>
    
    <span class="k">def</span> <span class="nf">_get_adjust_set</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> 
        <span class="n">minimize</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="p">):</span>
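<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns the Adjust set (self.vancs), optionally minimized.</span>
<span class="sd">        </span>
<span class="sd">        If minimize is True, nodes with no open path from X (first pass), or</span>
<span class="sd">        with no open path to Y given X (second pass), are removed, each test</span>
<span class="sd">        conditioning on the remaining nodes. With minimize=&#39;keep_parentsYM&#39;,</span>
<span class="sd">        parents of Y and M are never removed. Returns False if the resulting</span>
<span class="sd">        set is not a valid adjustment set.</span>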
<span class="sd">        </span>
<span class="sd">        See van der Zander, B.; Liśkiewicz, M. &amp; Textor, J.</span>
<span class="sd">        Separators and adjustment sets in causal graphs: Complete </span>
<span class="sd">        criteria and an algorithmic framework </span>
<span class="sd">        Artificial Intelligence, Elsevier, 2019, 270, 1-40</span>

<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">vancs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">vancs</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>

        <span class="k">if</span> <span class="n">minimize</span><span class="p">:</span>
            <span class="c1"># Get removable nodes by computing minimal valid set from Z</span>
            <span class="k">if</span> <span class="n">minimize</span> <span class="o">==</span> <span class="s1">&#39;keep_parentsYM&#39;</span><span class="p">:</span>
                <span class="n">minimize_nodes</span> <span class="o">=</span> <span class="n">vancs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)))</span>

            <span class="k">else</span><span class="p">:</span>
                <span class="n">minimize_nodes</span> <span class="o">=</span> <span class="n">vancs</span>

            <span class="c1"># Zprime2 = Zprime</span>
            <span class="c1"># First remove all nodes that have no open path from X given the remaining nodes</span>
            <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">minimize_nodes</span><span class="p">:</span>
                <span class="c1"># path = self.oracle.check_shortest_path(X=X, Y=[node], </span>
                <span class="c1">#     Z=list(vancs - set([node])), </span>
                <span class="c1">#     max_lag=None, </span>
                <span class="c1">#     starts_with=None, #&#39;arrowhead&#39;, </span>
                <span class="c1">#     forbidden_nodes=None, #list(Zprime - set([node])), </span>
                <span class="c1">#     return_path=False)</span>
                <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
                    <span class="n">start</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> <span class="n">end</span><span class="o">=</span><span class="p">[</span><span class="n">node</span><span class="p">],</span> 
                    <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">vancs</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">node</span><span class="p">])),</span> 
                     <span class="p">)</span>
  
                <span class="k">if</span> <span class="n">path</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                    <span class="n">vancs</span> <span class="o">=</span> <span class="n">vancs</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">node</span><span class="p">])</span>

            <span class="k">if</span> <span class="n">minimize</span> <span class="o">==</span> <span class="s1">&#39;keep_parentsYM&#39;</span><span class="p">:</span>
                <span class="n">minimize_nodes</span> <span class="o">=</span> <span class="n">vancs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">M</span><span class="p">)))</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">minimize_nodes</span> <span class="o">=</span> <span class="n">vancs</span>

            <span class="c1"># print(Zprime2) </span>
            <span class="c1"># Next remove all nodes that have no open path to Y given the remaining nodes and X</span>
            <span class="c1"># Z = Zprime2</span>
            <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">minimize_nodes</span><span class="p">:</span>

                <span class="n">path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_path</span><span class="p">(</span><span class="c1">#graph=self.graph, </span>
                    <span class="n">start</span><span class="o">=</span><span class="p">[</span><span class="n">node</span><span class="p">],</span> <span class="n">end</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">,</span> 
                    <span class="n">conditions</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">vancs</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">node</span><span class="p">]))</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">),</span>
                    <span class="p">)</span>

                <span class="k">if</span> <span class="n">path</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                    <span class="n">vancs</span> <span class="o">=</span> <span class="n">vancs</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">node</span><span class="p">])</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_validity</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">vancs</span><span class="p">))</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">False</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">vancs</span><span class="p">)</span>
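
    <span class="c1"># Illustrative sketch only (hypothetical helper, not used by this class):</span>
    <span class="c1"># the two pruning passes of _get_adjust_set, assuming a hypothetical oracle</span>
    <span class="c1"># has_open_path(start, end, conditions) -&gt; bool in place of self._check_path.</span>
    <span class="c1"># The keep_parentsYM option and the final validity check are left out.</span>
    <span class="c1">#</span>
    <span class="c1"># def prune_adjust_set(candidates, X, Y, has_open_path):</span>
    <span class="c1">#     Z = set(candidates)</span>
    <span class="c1">#     # Pass 1: drop nodes with no open path from X given the rest of Z.</span>
    <span class="c1">#     for node in list(Z):</span>
    <span class="c1">#         if not has_open_path(list(X), [node], list(Z - {node})):</span>
    <span class="c1">#             Z.discard(node)</span>
    <span class="c1">#     # Pass 2: drop nodes with no open path to Y given the rest of Z and X.</span>
    <span class="c1">#     for node in list(Z):</span>
    <span class="c1">#         if not has_open_path([node], list(Y), list(Z - {node}) + list(X)):</span>
    <span class="c1">#             Z.discard(node)</span>
    <span class="c1">#     return Z</span>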


    <span class="k">def</span> <span class="nf">_get_all_valid_adjustment_sets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> 
        <span class="n">check_one_set_exists</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">yield_index</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
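<span class="w">        </span><span class="sd">&quot;&quot;&quot;Constructs all valid adjustment sets or just checks whether one exists.</span>
<span class="sd">        </span>
<span class="sd">        If check_one_set_exists is True, returns a bool. If yield_index is</span>
<span class="sd">        given, returns only the valid set with that enumeration index, or</span>
<span class="sd">        None if fewer sets exist. Otherwise returns the list of all sets.</span>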
<span class="sd">        </span>
<span class="sd">        See van der Zander, B.; Liśkiewicz, M. &amp; Textor, J.</span>
<span class="sd">        Separators and adjustment sets in causal graphs: Complete </span>
<span class="sd">        criteria and an algorithmic framework </span>
<span class="sd">        Artificial Intelligence, Elsevier, 2019, 270, 1-40</span>

<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">cond_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">S</span><span class="p">)</span>
        <span class="n">all_vars</span> <span class="o">=</span> <span class="p">[(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">N</span><span class="p">)</span>
                    <span class="k">for</span> <span class="n">tau</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>

        <span class="n">all_vars_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">all_vars</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">forbidden_nodes</span>


        <span class="k">def</span> <span class="nf">find_sep</span><span class="p">(</span><span class="n">I</span><span class="p">,</span> <span class="n">R</span><span class="p">):</span>
            <span class="n">Rprime</span> <span class="o">=</span> <span class="n">R</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span>
            <span class="c1"># TODO: anteriors and NOT ancestors where</span>
            <span class="c1"># anteriors include --- links in causal paths</span>
            <span class="c1"># print(I)</span>
            <span class="n">XYI</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">I</span><span class="p">))</span>
            <span class="c1"># print(XYI)</span>
            <span class="n">ancs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_ancestors</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">XYI</span><span class="p">))</span>
            <span class="n">Z</span> <span class="o">=</span> <span class="n">ancs</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">Rprime</span><span class="p">)</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_validity</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                <span class="k">return</span> <span class="kc">False</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="k">return</span> <span class="n">Z</span>


        <span class="k">def</span> <span class="nf">list_sep</span><span class="p">(</span><span class="n">I</span><span class="p">,</span> <span class="n">R</span><span class="p">):</span>
            <span class="c1"># print(find_sep(X, Y, I, R))</span>
            <span class="k">if</span> <span class="n">find_sep</span><span class="p">(</span><span class="n">I</span><span class="p">,</span> <span class="n">R</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">False</span><span class="p">:</span>
                <span class="c1"># print(I,R)</span>
                <span class="k">if</span> <span class="n">I</span> <span class="o">==</span> <span class="n">R</span><span class="p">:</span> 
                    <span class="c1"># print(&#39;---&gt;&#39;, I)</span>
                    <span class="k">yield</span> <span class="n">I</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="c1"># Pick arbitrary node from R-I</span>
                    <span class="n">RminusI</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">R</span> <span class="o">-</span> <span class="n">I</span><span class="p">)</span>
                    <span class="c1"># print(R, I, RminusI)</span>
                    <span class="n">v</span> <span class="o">=</span> <span class="n">RminusI</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
                    <span class="c1"># print(&quot;here &quot;, X, Y, I.union(set([v])), R)</span>
                    <span class="k">yield from</span> <span class="n">list_sep</span><span class="p">(</span><span class="n">I</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">v</span><span class="p">])),</span> <span class="n">R</span><span class="p">)</span>
                    <span class="k">yield from</span> <span class="n">list_sep</span><span class="p">(</span><span class="n">I</span><span class="p">,</span> <span class="n">R</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">v</span><span class="p">]))</span>
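
        <span class="c1"># Illustrative sketch only (hypothetical helper, not used here): the same</span>
        <span class="c1"># branch-and-prune enumeration outside the graph setting. find_witness(I, R)</span>
        <span class="c1"># is a hypothetical stand-in for find_sep and is assumed to return False</span>
        <span class="c1"># exactly when no admissible set Z with I &lt;= Z &lt;= R exists. Each branch</span>
        <span class="c1"># either adds v to I or removes v from R, so every admissible set is</span>
        <span class="c1"># yielded exactly once.</span>
        <span class="c1">#</span>
        <span class="c1"># def enumerate_sets(I, R, find_witness):</span>
        <span class="c1">#     if find_witness(I, R) is False:</span>
        <span class="c1">#         return</span>
        <span class="c1">#     if I == R:</span>
        <span class="c1">#         yield set(I)</span>
        <span class="c1">#         return</span>
        <span class="c1">#     v = next(iter(R - I))</span>
        <span class="c1">#     yield from enumerate_sets(I | {v}, R, find_witness)</span>
        <span class="c1">#     yield from enumerate_sets(I, R - {v}, find_witness)</span>
        <span class="c1">#</span>
        <span class="c1"># Toy usage: admissible sets contain &#39;a&#39; and not &#39;c&#39;.</span>
        <span class="c1"># ok = lambda I, R: &#39;a&#39; in R and &#39;c&#39; not in I</span>
        <span class="c1"># sorted(sorted(s) for s in enumerate_sets(set(), {&#39;a&#39;, &#39;b&#39;, &#39;c&#39;}, ok))</span>
        <span class="c1"># # -&gt; [[&#39;a&#39;], [&#39;a&#39;, &#39;b&#39;]]</span>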

        <span class="c1"># print(&quot;all &quot;, X, Y, cond_set, all_vars_set)</span>
        <span class="n">all_sets</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="n">I</span> <span class="o">=</span> <span class="n">cond_set</span>
        <span class="n">R</span> <span class="o">=</span> <span class="n">all_vars_set</span>
        <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">valid_set</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">list_sep</span><span class="p">(</span><span class="n">I</span><span class="p">,</span> <span class="n">R</span><span class="p">)):</span>
            <span class="c1"># print(valid_set)</span>
            <span class="n">all_sets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">valid_set</span><span class="p">))</span>
            <span class="k">if</span> <span class="n">check_one_set_exists</span> <span class="ow">and</span> <span class="n">index</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="k">break</span>

            <span class="k">if</span> <span class="n">yield_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">index</span> <span class="o">==</span> <span class="n">yield_index</span><span class="p">:</span>
                <span class="k">return</span> <span class="n">valid_set</span>

        <span class="k">if</span> <span class="n">yield_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">return</span> <span class="kc">None</span>

        <span class="k">if</span> <span class="n">check_one_set_exists</span><span class="p">:</span>
            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">all_sets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="k">return</span> <span class="kc">True</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="k">return</span> <span class="kc">False</span>

        <span class="k">return</span> <span class="n">all_sets</span>


    <span class="k">def</span> <span class="nf">_get_causal_paths</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">source_nodes</span><span class="p">,</span> <span class="n">target_nodes</span><span class="p">,</span>
        <span class="n">mediators</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">mediated_through</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">proper_paths</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns causal paths via depth-first search.</span>

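<span class="sd">        Only nodes in mediators and target_nodes (minus source_nodes if</span>
<span class="sd">        proper_paths is True) are traversed. If mediated_through is given,</span>
<span class="sd">        only paths passing through at least one of its nodes are kept.</span>
<span class="sd">        Returns a dict all_causal_paths[source][target] of path lists.</span>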

<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">source_nodes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">source_nodes</span><span class="p">)</span>
        <span class="n">target_nodes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">target_nodes</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">mediators</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">mediators</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">mediators</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">mediators</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">mediated_through</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">mediated_through</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="n">mediated_through</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">mediated_through</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">proper_paths</span><span class="p">:</span>
            <span class="n">inside_set</span> <span class="o">=</span> <span class="n">mediators</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">target_nodes</span><span class="p">)</span> <span class="o">-</span> <span class="n">source_nodes</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">inside_set</span> <span class="o">=</span> <span class="n">mediators</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">target_nodes</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="n">source_nodes</span><span class="p">)</span>

        <span class="n">all_causal_paths</span> <span class="o">=</span> <span class="p">{}</span>         
        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">source_nodes</span><span class="p">:</span>
            <span class="n">all_causal_paths</span><span class="p">[</span><span class="n">w</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="n">target_nodes</span><span class="p">:</span>
                <span class="n">all_causal_paths</span><span class="p">[</span><span class="n">w</span><span class="p">][</span><span class="n">z</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>

        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">source_nodes</span><span class="p">:</span>
            
            <span class="n">causal_path</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="n">queue</span> <span class="o">=</span> <span class="p">[(</span><span class="n">w</span><span class="p">,</span> <span class="n">causal_path</span><span class="p">)]</span>

            <span class="k">while</span> <span class="n">queue</span><span class="p">:</span>

                <span class="n">varlag</span><span class="p">,</span> <span class="n">causal_path</span> <span class="o">=</span> <span class="n">queue</span><span class="o">.</span><span class="n">pop</span><span class="p">()</span>
                <span class="n">causal_path</span> <span class="o">=</span> <span class="n">causal_path</span> <span class="o">+</span> <span class="p">[</span><span class="n">varlag</span><span class="p">]</span>
                <span class="n">suitable_nodes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_children</span><span class="p">(</span><span class="n">varlag</span><span class="p">)</span>
                    <span class="p">)</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">inside_set</span><span class="p">)</span>
                <span class="k">for</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">suitable_nodes</span><span class="p">:</span>
                    <span class="n">i</span><span class="p">,</span> <span class="n">tau</span> <span class="o">=</span> <span class="n">node</span>
                    <span class="k">if</span> <span class="p">((</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">&lt;=</span> <span class="n">tau</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># or self.ignore_time_bounds)</span>
                        <span class="ow">and</span> <span class="n">node</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">causal_path</span><span class="p">):</span>

                        <span class="n">queue</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">node</span><span class="p">,</span> <span class="n">causal_path</span><span class="p">))</span> 
 
                        <span class="k">if</span> <span class="n">node</span> <span class="ow">in</span> <span class="n">target_nodes</span><span class="p">:</span>  
                            <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mediated_through</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">causal_path</span><span class="p">)</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">mediated_through</span><span class="p">))</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                                <span class="k">continue</span>
                            <span class="k">else</span><span class="p">:</span>
                                <span class="n">all_causal_paths</span><span class="p">[</span><span class="n">w</span><span class="p">][</span><span class="n">node</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">causal_path</span> <span class="o">+</span> <span class="p">[</span><span class="n">node</span><span class="p">])</span> 

        <span class="k">return</span> <span class="n">all_causal_paths</span>
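
    <span class="c1"># Illustrative sketch only (hypothetical helper, not used by this class):</span>
    <span class="c1"># the stack-based DFS of _get_causal_paths on a plain adjacency dict.</span>
    <span class="c1"># children and inside are hypothetical stand-ins for self._get_children and</span>
    <span class="c1"># inside_set; lag bounds and the mediated_through filter are omitted.</span>
    <span class="c1">#</span>
    <span class="c1"># def directed_paths(children, source, targets, inside):</span>
    <span class="c1">#     paths = []</span>
    <span class="c1">#     stack = [(source, [])]</span>
    <span class="c1">#     while stack:</span>
    <span class="c1">#         node, path = stack.pop()</span>
    <span class="c1">#         path = path + [node]</span>
    <span class="c1">#         for child in children.get(node, ()):</span>
    <span class="c1">#             # Only traverse allowed nodes and never revisit a node on this path.</span>
    <span class="c1">#             if child not in inside or child in path:</span>
    <span class="c1">#                 continue</span>
    <span class="c1">#             stack.append((child, path))</span>
    <span class="c1">#             if child in targets:</span>
    <span class="c1">#                 paths.append(path + [child])</span>
    <span class="c1">#     return paths</span>
    <span class="c1">#</span>
    <span class="c1"># children = {&#39;x&#39;: [&#39;m&#39;, &#39;y&#39;], &#39;m&#39;: [&#39;y&#39;]}</span>
    <span class="c1"># directed_paths(children, &#39;x&#39;, {&#39;y&#39;}, {&#39;m&#39;, &#39;y&#39;})</span>
    <span class="c1"># # -&gt; [[&#39;x&#39;, &#39;y&#39;], [&#39;x&#39;, &#39;m&#39;, &#39;y&#39;]]</span>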

<div class="viewcode-block" id="CausalEffects.fit_total_effect"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.fit_total_effect">[docs]</a>    <span class="k">def</span> <span class="nf">fit_total_effect</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
        <span class="n">dataframe</span><span class="p">,</span> 
        <span class="n">estimator</span><span class="p">,</span>
        <span class="n">adjustment_set</span><span class="o">=</span><span class="s1">&#39;optimal&#39;</span><span class="p">,</span>
        <span class="n">conditional_estimator</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>  
        <span class="n">data_transform</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">mask_type</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">ignore_identifiability</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Returns a fitted model for the total causal effect of X on Y </span>
<span class="sd">           conditional on S.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        dataframe : data object</span>
<span class="sd">            Tigramite dataframe object. It must have the attributes dataframe.values</span>
<span class="sd">            yielding a numpy array of shape (observations T, variables N) and</span>
<span class="sd">            optionally a mask of the same shape and a missing values flag.</span>
<span class="sd">        estimator : sklearn model object</span>
<span class="sd">            For example, sklearn.linear_model.LinearRegression() for a linear</span>
<span class="sd">            regression model.</span>
<span class="sd">        adjustment_set : str or list of tuples</span>
<span class="sd">            If &#39;optimal&#39;, the Oset is used; if &#39;minimized_optimal&#39;, the minimized</span>
<span class="sd">            Oset; and if &#39;colliders_minimized_optimal&#39;, the colliders-minimized Oset.</span>
<span class="sd">            If a list of tuples is passed, this set is used.</span>
<span class="sd">        conditional_estimator : sklearn model object, optional (default: None)</span>
<span class="sd">            Used to fit conditional causal effects in nested regression. </span>
<span class="sd">            If None, the same model as for estimator is used.</span>
<span class="sd">        data_transform : sklearn preprocessing object, optional (default: None)</span>
<span class="sd">            Used to transform data prior to fitting. For example,</span>
<span class="sd">            sklearn.preprocessing.StandardScaler for simple standardization. The</span>
<span class="sd">            fitted parameters are stored.</span>
<span class="sd">        mask_type : {None, &#39;y&#39;,&#39;x&#39;,&#39;z&#39;,&#39;xy&#39;,&#39;xz&#39;,&#39;yz&#39;,&#39;xyz&#39;}</span>
<span class="sd">            Masking mode: Indicators for which variables in the dependence</span>
<span class="sd">            measure I(X; Y | Z) the samples should be masked. If None, the mask</span>
<span class="sd">            is not used. Explained in tutorial on masking and missing values.</span>
<span class="sd">        ignore_identifiability : bool, optional (default: False)</span>
<span class="sd">            Only applies to adjustment sets supplied by the user. If True,</span>
<span class="sd">            the set is used even if it does not lead to an identifiable effect.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No causal path from X to Y exists.&quot;</span><span class="p">)</span>
            <span class="k">return</span> <span class="bp">self</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">dataframe</span> <span class="o">=</span> <span class="n">dataframe</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conditional_estimator</span> <span class="o">=</span> <span class="n">conditional_estimator</span>

        <span class="c1"># if self.dataframe.has_vector_data:</span>
        <span class="c1">#     raise ValueError(&quot;vector_vars in DataFrame cannot be used together with CausalEffects!&quot;</span>
        <span class="c1">#                      &quot; You can estimate vector-valued effects by using multivariate X, Y, S.&quot;</span>
        <span class="c1">#                      &quot; Note, however, that this requires assuming a graph at the level &quot;</span>
        <span class="c1">#                      &quot;of the components of X, Y, S, ...&quot;)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">N</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataframe</span><span class="o">.</span><span class="n">N</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Dataset dimensions inconsistent with number of variables in graph.&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">adjustment_set</span> <span class="o">==</span> <span class="s1">&#39;optimal&#39;</span><span class="p">:</span>
            <span class="c1"># Check optimality and use either optimal or colliders_only set</span>
            <span class="n">adjustment_set</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimal_set</span><span class="p">()</span>
        <span class="k">elif</span> <span class="n">adjustment_set</span> <span class="o">==</span> <span class="s1">&#39;colliders_minimized_optimal&#39;</span><span class="p">:</span>
            <span class="n">adjustment_set</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimal_set</span><span class="p">(</span><span class="n">minimize</span><span class="o">=</span><span class="s1">&#39;colliders_only&#39;</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">adjustment_set</span> <span class="o">==</span> <span class="s1">&#39;minimized_optimal&#39;</span><span class="p">:</span>
            <span class="n">adjustment_set</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimal_set</span><span class="p">(</span><span class="n">minimize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">ignore_identifiability</span> <span class="ow">is</span> <span class="kc">False</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_validity</span><span class="p">(</span><span class="n">adjustment_set</span><span class="p">)</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Chosen adjustment_set is not valid.&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">adjustment_set</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Causal effect not identifiable via adjustment.&quot;</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">adjustment_set</span> <span class="o">=</span> <span class="n">adjustment_set</span>

        <span class="c1"># Fit model of Y on X and Z (and conditions)</span>
        <span class="c1"># Build the model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Models</span><span class="p">(</span>
                        <span class="n">dataframe</span><span class="o">=</span><span class="n">dataframe</span><span class="p">,</span>
                        <span class="n">model</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span>
                        <span class="n">conditional_model</span><span class="o">=</span><span class="n">conditional_estimator</span><span class="p">,</span>
                        <span class="n">data_transform</span><span class="o">=</span><span class="n">data_transform</span><span class="p">,</span>
                        <span class="n">mask_type</span><span class="o">=</span><span class="n">mask_type</span><span class="p">,</span>
                        <span class="n">verbosity</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span><span class="p">)</span>      

        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">get_general_fitted_model</span><span class="p">(</span>
                <span class="n">Y</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">listX</span><span class="p">,</span> <span class="n">Z</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">adjustment_set</span><span class="p">),</span>
                <span class="n">conditions</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">listS</span><span class="p">,</span>
                <span class="n">tau_max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span><span class="p">,</span>
                <span class="n">cut_off</span><span class="o">=</span><span class="s1">&#39;tau_max&#39;</span><span class="p">,</span>
                <span class="n">return_data</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>

        <span class="k">return</span> <span class="bp">self</span></div>
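</pre></div>
<p>For a linear estimator, the adjustment step above amounts to regressing Y on X together with the adjustment set Z. The following standalone NumPy sketch illustrates this on a hypothetical toy model in which Z confounds X and Y. It is not tigramite&#39;s implementation. The toy model, its coefficients, and the helper names <code>simulate_confounded</code> and <code>ols_coefficients</code> are assumptions chosen for illustration only.</p>
<div class="highlight"><pre>
```python
import numpy as np


def simulate_confounded(n=20000, beta=0.5, seed=0):
    # Hypothetical toy model (assumption): Z causes both X and Y, and X
    # causes Y with coefficient beta. Z is then a valid adjustment set.
    rng = np.random.default_rng(seed)
    z = rng.normal(size=n)
    x = 0.8 * z + rng.normal(size=n)
    y = beta * x + 1.2 * z + rng.normal(size=n)
    return x, y, z


def ols_coefficients(target, *regressors):
    # Least-squares fit of target on an intercept plus the given regressors.
    design = np.column_stack([np.ones_like(target), *regressors])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef[1:]  # drop the intercept


x, y, z = simulate_confounded()
naive = ols_coefficients(y, x)[0]        # biased: Z is a confounder
adjusted = ols_coefficients(y, x, z)[0]  # close to beta = 0.5
```
</pre></div>
<p>In this toy setup the coefficient on X is close to 0.5 only when Z is included in the regression; the naive regression of Y on X alone is biased upward by the confounding path through Z.</p>
<div class="highlight"><pre>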

<div class="viewcode-block" id="CausalEffects.predict_total_effect"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.predict_total_effect">[docs]</a>    <span class="k">def</span> <span class="nf">predict_total_effect</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> 
        <span class="n">intervention_data</span><span class="p">,</span> 
        <span class="n">conditions_data</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">pred_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">return_further_pred_results</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="n">aggregation_func</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">,</span>
        <span class="n">transform_interventions_and_prediction</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Predict effect of intervention with fitted model.</span>

<span class="sd">        Uses the model.predict() function of the sklearn model.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        intervention_data : numpy array</span>
<span class="sd">            Numpy array of shape (time, len(X)) that contains the do(X) values.</span>
<span class="sd">        conditions_data : numpy array, optional</span>
<span class="sd">            Numpy array of shape (time, len(S)) that contains the S=s values.</span>
<span class="sd">        pred_params : dict, optional</span>
<span class="sd">            Optional parameters passed on to the sklearn prediction function.</span>
<span class="sd">        return_further_pred_results : bool, optional (default: False)</span>
<span class="sd">            If the predictor class returns more than just the expected value,</span>
<span class="sd">            whether to return these further results as well.</span>
<span class="sd">        aggregation_func : callable, optional (default: np.mean)</span>
<span class="sd">            Callable applied to the output of &#39;predict&#39;.</span>
<span class="sd">        transform_interventions_and_prediction : bool (default: False)</span>
<span class="sd">            Whether to apply the fitted data_transform to intervention_data</span>
<span class="sd">            (and conditions_data) and its inverse to the prediction results.</span>
<span class="sd">        </span>
<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        Results from prediction: an array of shape (time, len(Y)).</span>
<span class="sd">        If return_further_pred_results is True, a tuple is returned. If no</span>
<span class="sd">        causal path from X to Y exists, an array of zeros is returned.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="k">def</span> <span class="nf">get_vectorized_length</span><span class="p">(</span><span class="n">W</span><span class="p">):</span>
            <span class="k">return</span> <span class="nb">sum</span><span class="p">([</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataframe</span><span class="o">.</span><span class="n">vector_vars</span><span class="p">[</span><span class="n">w</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span> <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">W</span><span class="p">])</span>

        <span class="c1"># lenX = len(self.listX)</span>
        <span class="c1"># lenS = len(self.listS)</span>

        <span class="n">lenX</span> <span class="o">=</span> <span class="n">get_vectorized_length</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listX</span><span class="p">)</span>
        <span class="n">lenS</span> <span class="o">=</span> <span class="n">get_vectorized_length</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listS</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">intervention_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">lenX</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;intervention_data.shape[1] must be len(X).&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">conditions_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">lenS</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">conditions_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">lenS</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;conditions_data.shape[1] must be len(S).&quot;</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">conditions_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="n">intervention_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;conditions_data.shape[0] must match intervention_data.shape[0].&quot;</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">conditions_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">lenS</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;conditions_data specified, but S=None or empty.&quot;</span><span class="p">)</span>
        <span class="k">elif</span> <span class="n">conditions_data</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">lenS</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;S specified, but conditions_data is None.&quot;</span><span class="p">)</span>


        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No causal path from X to Y exists.&quot;</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">intervention_data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">)))</span>

        <span class="n">effect</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">get_general_prediction</span><span class="p">(</span>
            <span class="n">intervention_data</span><span class="o">=</span><span class="n">intervention_data</span><span class="p">,</span>
            <span class="n">conditions_data</span><span class="o">=</span><span class="n">conditions_data</span><span class="p">,</span>
            <span class="n">pred_params</span><span class="o">=</span><span class="n">pred_params</span><span class="p">,</span>
            <span class="n">return_further_pred_results</span><span class="o">=</span><span class="n">return_further_pred_results</span><span class="p">,</span>
            <span class="n">transform_interventions_and_prediction</span><span class="o">=</span><span class="n">transform_interventions_and_prediction</span><span class="p">,</span>
            <span class="n">aggregation_func</span><span class="o">=</span><span class="n">aggregation_func</span><span class="p">,)</span> 

        <span class="k">return</span> <span class="n">effect</span></div>
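</pre></div>
<p>The idea of predicting the effect of do(X=x) with a fitted linear model and aggregating with <code>np.mean</code> (the default <code>aggregation_func</code>) can be sketched as follows, assuming the aggregation runs over the samples of the adjustment variable: for each intervention value, X is set to that value for every sample, Y is predicted from the observed Z, and the predictions are averaged. This is a hypothetical NumPy illustration, not tigramite&#39;s <code>get_general_prediction</code>; the helpers <code>fit_linear</code> and <code>predict_do</code> and the simulated data are assumptions.</p>
<div class="highlight"><pre>
```python
import numpy as np


def fit_linear(y, x, z):
    # Least-squares fit of y on [1, x, z]; returns (intercept, b_x, b_z).
    design = np.column_stack([np.ones_like(y), x, z])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


def predict_do(coef, z, intervention_values, aggregation_func=np.mean):
    # For each do(X = x) value, set X to x for every sample, predict Y using
    # the observed adjustment variable z, then aggregate over samples.
    intercept, b_x, b_z = coef
    effects = []
    for x_val in np.asarray(intervention_values, dtype=float):
        y_pred = intercept + b_x * x_val + b_z * z
        effects.append(aggregation_func(y_pred))
    # One row per intervention value and one column for the single Y.
    return np.array(effects).reshape(-1, 1)


# Hypothetical toy data: Z confounds X and Y, true effect of X on Y is 0.5.
rng = np.random.default_rng(1)
z = rng.normal(size=5000)
x = 0.8 * z + rng.normal(size=5000)
y = 0.5 * x + 1.2 * z + rng.normal(size=5000)

coef = fit_linear(y, x, z)
intervention_data = np.array([[0.0], [1.0]])  # shape (time, len(X))
effect = predict_do(coef, z, intervention_data[:, 0])
```
</pre></div>
<p>The difference between the two rows of <code>effect</code> equals the fitted coefficient on X (close to 0.5 in this toy model), and the output has one row per intervention value, matching the (time, len(Y)) layout described in the docstring.</p>
<div class="highlight"><pre>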

<div class="viewcode-block" id="CausalEffects.fit_wright_effect"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.fit_wright_effect">[docs]</a>    <span class="k">def</span> <span class="nf">fit_wright_effect</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
        <span class="n">dataframe</span><span class="p">,</span> 
        <span class="n">mediation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">method</span><span class="o">=</span><span class="s1">&#39;parents&#39;</span><span class="p">,</span>
        <span class="n">links_coeffs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>  
        <span class="n">data_transform</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="n">mask_type</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Fits a linear model for the total or mediated causal effect of X on Y,</span>
<span class="sd">           optionally restricted to causal paths through mediator variables.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        dataframe : data object</span>
<span class="sd">            Tigramite dataframe object. It must have the attributes dataframe.values</span>
<span class="sd">            yielding a numpy array of shape (observations T, variables N) and</span>
<span class="sd">            optionally a mask of the same shape and a missing values flag.</span>
<span class="sd">        mediation : None, &#39;direct&#39;, or list of tuples</span>
<span class="sd">            If None, the total effect is estimated; if &#39;direct&#39;, only the direct effect is estimated;</span>
<span class="sd">            otherwise only those causal paths are considered that pass through at least one of these mediator nodes.</span>
<span class="sd">        method : {&#39;parents&#39;, &#39;links_coeffs&#39;, &#39;optimal&#39;}</span>
<span class="sd">            Method to use for estimating Wright&#39;s path coefficients. If &#39;optimal&#39;,</span>
<span class="sd">            each coefficient is estimated using the optimal adjustment set (Oset); if</span>
<span class="sd">            &#39;links_coeffs&#39;, the coefficients given in links_coeffs are used; if</span>
<span class="sd">            &#39;parents&#39;, the parents are used (only valid for DAGs).</span>
<span class="sd">        links_coeffs : dict</span>
<span class="sd">            Only used if method = &#39;links_coeffs&#39;.</span>
<span class="sd">            Dictionary of format: {0:[((i, -tau), coeff, func),...], 1:[...],</span>
<span class="sd">            ...} for all variables where i must be in [0..N-1] and tau &gt;= 0 with</span>
<span class="sd">            number of variables N. coeff must be a float; the third entry of each</span>
<span class="sd">            tuple (e.g., a link function) is ignored here.</span>
<span class="sd">        data_transform : None</span>
<span class="sd">            Not implemented for the Wright estimator (complicated for missing</span>
<span class="sd">            samples); preprocess the data beforehand instead.</span>
<span class="sd">        mask_type : {None, &#39;y&#39;,&#39;x&#39;,&#39;z&#39;,&#39;xy&#39;,&#39;xz&#39;,&#39;yz&#39;,&#39;xyz&#39;}</span>
<span class="sd">            Masking mode: Indicators for which variables in the dependence</span>
<span class="sd">            measure I(X; Y | Z) the samples should be masked. If None, the mask</span>
<span class="sd">            is not used. Explained in tutorial on masking and missing values.</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No causal path from X to Y exists.&quot;</span><span class="p">)</span>
            <span class="k">return</span> <span class="bp">self</span>

        <span class="k">if</span> <span class="n">data_transform</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;data_transform not implemented for Wright estimator.&quot;</span>
                             <span class="s2">&quot; You can preprocess data yourself beforehand.&quot;</span><span class="p">)</span>

        <span class="kn">import</span> <span class="nn">sklearn.linear_model</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">dataframe</span> <span class="o">=</span> <span class="n">dataframe</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataframe</span><span class="o">.</span><span class="n">has_vector_data</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;vector_vars in DataFrame cannot be used together with Wright method!&quot;</span>
                             <span class="s2">&quot; You can either 1) estimate vector-valued effects by using multivariate (X, Y, S)&quot;</span>
                             <span class="s2">&quot; together with assuming a graph at the level of the components of (X, Y, S),&quot;</span>
                             <span class="s2">&quot; or 2) use vector_vars together with fit_total_effect and an estimator&quot;</span>
                             <span class="s2">&quot; that supports multiple outputs.&quot;</span><span class="p">)</span>

        <span class="n">estimator</span> <span class="o">=</span> <span class="n">sklearn</span><span class="o">.</span><span class="n">linear_model</span><span class="o">.</span><span class="n">LinearRegression</span><span class="p">()</span>

        <span class="c1"># Fit model of Y on X and Z (and conditions)</span>
        <span class="c1"># Build the model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">Models</span><span class="p">(</span>
                        <span class="n">dataframe</span><span class="o">=</span><span class="n">dataframe</span><span class="p">,</span>
                        <span class="n">model</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span>
                        <span class="n">data_transform</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="c1">#data_transform,</span>
                        <span class="n">mask_type</span><span class="o">=</span><span class="n">mask_type</span><span class="p">,</span>
                        <span class="n">verbosity</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span><span class="p">)</span>

        <span class="n">mediators</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">M</span>  <span class="c1"># self.get_mediators(start=self.X, end=self.Y)</span>

        <span class="k">if</span> <span class="n">mediation</span> <span class="o">==</span> <span class="s1">&#39;direct&#39;</span><span class="p">:</span>
            <span class="n">causal_paths</span> <span class="o">=</span> <span class="p">{}</span>         
            <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">:</span>
                <span class="n">causal_paths</span><span class="p">[</span><span class="n">w</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
                <span class="k">for</span> <span class="n">z</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">:</span>
                    <span class="k">if</span> <span class="n">w</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_parents</span><span class="p">(</span><span class="n">z</span><span class="p">):</span>
                        <span class="n">causal_paths</span><span class="p">[</span><span class="n">w</span><span class="p">][</span><span class="n">z</span><span class="p">]</span> <span class="o">=</span> <span class="p">[[</span><span class="n">w</span><span class="p">,</span> <span class="n">z</span><span class="p">]]</span>
                    <span class="k">else</span><span class="p">:</span>
                        <span class="n">causal_paths</span><span class="p">[</span><span class="n">w</span><span class="p">][</span><span class="n">z</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">causal_paths</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_causal_paths</span><span class="p">(</span><span class="n">source_nodes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">,</span> 
                <span class="n">target_nodes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">,</span> <span class="n">mediators</span><span class="o">=</span><span class="n">mediators</span><span class="p">,</span> 
                <span class="n">mediated_through</span><span class="o">=</span><span class="n">mediation</span><span class="p">,</span> <span class="n">proper_paths</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;links_coeffs&#39;</span><span class="p">:</span>
            <span class="n">coeffs</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="n">max_lag</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="k">for</span> <span class="n">medy</span> <span class="ow">in</span> <span class="p">[</span><span class="n">med</span> <span class="k">for</span> <span class="n">med</span> <span class="ow">in</span> <span class="n">mediators</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">]:</span>
                <span class="n">coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
                <span class="n">j</span><span class="p">,</span> <span class="n">tauj</span> <span class="o">=</span> <span class="n">medy</span>
                <span class="k">for</span> <span class="n">ipar</span><span class="p">,</span> <span class="n">par_coeff</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">links_coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">[</span><span class="mi">0</span><span class="p">]]):</span>
                    <span class="n">par</span><span class="p">,</span> <span class="n">coeff</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">par_coeff</span>
                    <span class="n">i</span><span class="p">,</span> <span class="n">taui</span> <span class="o">=</span> <span class="n">par</span>
                    <span class="n">taui_shifted</span> <span class="o">=</span> <span class="n">taui</span> <span class="o">+</span> <span class="n">tauj</span>
                    <span class="n">max_lag</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="nb">abs</span><span class="p">(</span><span class="n">par</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">max_lag</span><span class="p">)</span>
                    <span class="n">coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">][(</span><span class="n">i</span><span class="p">,</span> <span class="n">taui_shifted</span><span class="p">)]</span> <span class="o">=</span> <span class="n">coeff</span> <span class="c1">#self.fit_results[j][(j, 0)][&#39;model&#39;].coef_[ipar]</span>

            <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">tau_max</span> <span class="o">=</span> <span class="n">max_lag</span>
            <span class="c1"># print(coeffs)</span>

        <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;optimal&#39;</span><span class="p">:</span>
            <span class="c1"># all_parents = {}</span>
            <span class="n">coeffs</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="k">for</span> <span class="n">medy</span> <span class="ow">in</span> <span class="p">[</span><span class="n">med</span> <span class="k">for</span> <span class="n">med</span> <span class="ow">in</span> <span class="n">mediators</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">]:</span>
                <span class="n">coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
                <span class="n">mediator_parents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="n">mediators</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">X</span><span class="p">)</span><span class="o">.</span><span class="n">union</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">))</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span>
                <span class="n">all_parents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span>
                <span class="k">for</span> <span class="n">par</span> <span class="ow">in</span> <span class="n">mediator_parents</span><span class="p">:</span>
                    <span class="n">Sprime</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">all_parents</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">par</span><span class="p">,</span> <span class="n">medy</span><span class="p">])</span>
                    <span class="n">causal_effects</span> <span class="o">=</span> <span class="n">CausalEffects</span><span class="p">(</span><span class="n">graph</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">graph</span><span class="p">,</span> 
                                        <span class="n">X</span><span class="o">=</span><span class="p">[</span><span class="n">par</span><span class="p">],</span> <span class="n">Y</span><span class="o">=</span><span class="p">[</span><span class="n">medy</span><span class="p">],</span> <span class="n">S</span><span class="o">=</span><span class="n">Sprime</span><span class="p">,</span>
                                        <span class="n">graph_type</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span><span class="p">,</span>
                                        <span class="n">check_SM_overlap</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                                        <span class="p">)</span>
                    <span class="n">oset</span> <span class="o">=</span> <span class="n">causal_effects</span><span class="o">.</span><span class="n">get_optimal_set</span><span class="p">()</span>
                    <span class="c1"># print(medy, par, list(set(all_parents)), oset)</span>
                    <span class="k">if</span> <span class="n">oset</span> <span class="ow">is</span> <span class="kc">False</span><span class="p">:</span>
                        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Not identifiable via Wright&#39;s method.&quot;</span><span class="p">)</span>
                    <span class="n">fit_res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">get_general_fitted_model</span><span class="p">(</span>
                        <span class="n">Y</span><span class="o">=</span><span class="p">[</span><span class="n">medy</span><span class="p">],</span> <span class="n">X</span><span class="o">=</span><span class="p">[</span><span class="n">par</span><span class="p">],</span> <span class="n">Z</span><span class="o">=</span><span class="n">oset</span><span class="p">,</span>
                        <span class="n">tau_max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span><span class="p">,</span>
                        <span class="n">cut_off</span><span class="o">=</span><span class="s1">&#39;tau_max&#39;</span><span class="p">,</span>
                        <span class="n">return_data</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
                    <span class="n">coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">][</span><span class="n">par</span><span class="p">]</span> <span class="o">=</span> <span class="n">fit_res</span><span class="p">[</span><span class="s1">&#39;model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">coef_</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

        <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;parents&#39;</span><span class="p">:</span>
            <span class="n">coeffs</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="k">for</span> <span class="n">medy</span> <span class="ow">in</span> <span class="p">[</span><span class="n">med</span> <span class="k">for</span> <span class="n">med</span> <span class="ow">in</span> <span class="n">mediators</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">y</span> <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">]:</span>
                <span class="n">coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
                <span class="c1"># mediator_parents = self._get_all_parents([medy]).intersection(mediators.union(self.X)) - set([medy])</span>
                <span class="n">all_parents</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_parents</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span>
                <span class="k">if</span> <span class="s1">&#39;dag&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">graph_type</span><span class="p">:</span>
                    <span class="n">spouses</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_all_spouses</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">medy</span><span class="p">])</span>
                    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">spouses</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
                        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;method == &#39;parents&#39; only possible for &quot;</span>
                                         <span class="s2">&quot;causal paths without adjacent bi-directed links!&quot;</span><span class="p">)</span>

                <span class="c1"># print(j, all_parents[j])</span>
                <span class="c1"># if len(all_parents[j]) &gt; 0:</span>
                <span class="c1"># print(medy, list(all_parents))</span>
                <span class="n">fit_res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">get_general_fitted_model</span><span class="p">(</span>
                    <span class="n">Y</span><span class="o">=</span><span class="p">[</span><span class="n">medy</span><span class="p">],</span> <span class="n">X</span><span class="o">=</span><span class="nb">list</span><span class="p">(</span><span class="n">all_parents</span><span class="p">),</span> <span class="n">Z</span><span class="o">=</span><span class="p">[],</span>
                    <span class="n">conditions</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                    <span class="n">tau_max</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">tau_max</span><span class="p">,</span>
                    <span class="n">cut_off</span><span class="o">=</span><span class="s1">&#39;tau_max&#39;</span><span class="p">,</span>
                    <span class="n">return_data</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>

                <span class="k">for</span> <span class="n">ipar</span><span class="p">,</span> <span class="n">par</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">all_parents</span><span class="p">)):</span>
                    <span class="c1"># print(par, fit_res[&#39;model&#39;].coef_)</span>
                    <span class="n">coeffs</span><span class="p">[</span><span class="n">medy</span><span class="p">][</span><span class="n">par</span><span class="p">]</span> <span class="o">=</span> <span class="n">fit_res</span><span class="p">[</span><span class="s1">&#39;model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">coef_</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">ipar</span><span class="p">]</span>

        <span class="k">else</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;method must be &#39;optimal&#39;, &#39;links_coeffs&#39;, or &#39;parents&#39;.&quot;</span><span class="p">)</span>
        
        <span class="c1"># Effect is sum over products over all path coefficients</span>
        <span class="c1"># from x in X to y in Y</span>
        <span class="n">effect</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listX</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">):</span>
            <span class="n">effect</span><span class="p">[(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)]</span> <span class="o">=</span> <span class="mf">0.</span>
            <span class="k">for</span> <span class="n">causal_path</span> <span class="ow">in</span> <span class="n">causal_paths</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="n">y</span><span class="p">]:</span>
                <span class="n">effect_here</span> <span class="o">=</span> <span class="mf">1.</span>
                <span class="c1"># print(x, y, causal_path)</span>
                <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">node</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">causal_path</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]):</span>
                    <span class="n">i</span><span class="p">,</span> <span class="n">taui</span> <span class="o">=</span> <span class="n">node</span>
                    <span class="n">j</span><span class="p">,</span> <span class="n">tauj</span> <span class="o">=</span> <span class="n">causal_path</span><span class="p">[</span><span class="n">index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
                    <span class="c1"># tau_ij = abs(tauj - taui)</span>
                    <span class="c1"># print((j, tauj), (i, taui))</span>
                    <span class="n">effect_here</span> <span class="o">*=</span> <span class="n">coeffs</span><span class="p">[(</span><span class="n">j</span><span class="p">,</span> <span class="n">tauj</span><span class="p">)][(</span><span class="n">i</span><span class="p">,</span> <span class="n">taui</span><span class="p">)]</span>

                <span class="n">effect</span><span class="p">[(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)]</span> <span class="o">+=</span> <span class="n">effect_here</span>
               
        <span class="c1"># Make fitted coefficients available as attribute</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">coeffs</span> <span class="o">=</span> <span class="n">coeffs</span>

        <span class="c1"># Modify and overwrite variables in self.model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">Y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">X</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">listX</span>  
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">Z</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">conditions</span> <span class="o">=</span> <span class="p">[]</span> 
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">cut_off</span> <span class="o">=</span> <span class="s1">&#39;tau_max&#39;</span> <span class="c1"># &#39;max_lag_or_tau_max&#39;</span>

        <span class="k">class</span> <span class="nc">dummy_fit_class</span><span class="p">():</span>
            <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_here</span><span class="p">,</span> <span class="n">listX_here</span><span class="p">,</span> <span class="n">effect_here</span><span class="p">):</span>
                <span class="n">dim</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">listX_here</span><span class="p">)</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">coeff_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">effect_here</span><span class="p">[(</span><span class="n">x</span><span class="p">,</span> <span class="n">y_here</span><span class="p">)]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">listX_here</span><span class="p">])</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
            <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
                <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">coeff_array</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>

        <span class="n">fit_results</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">:</span>
            <span class="n">fit_results</span><span class="p">[</span><span class="n">y</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="n">fit_results</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="s1">&#39;model&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dummy_fit_class</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">listX</span><span class="p">,</span> <span class="n">effect</span><span class="p">)</span>
            <span class="n">fit_results</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="s1">&#39;data_transform&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">data_transform</span><span class="p">)</span>

        <span class="c1"># self.effect = effect</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">fit_results</span> <span class="o">=</span> <span class="n">fit_results</span>
        <span class="k">return</span> <span class="bp">self</span></div>
 
<div class="viewcode-block" id="CausalEffects.predict_wright_effect"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.predict_wright_effect">[docs]</a>    <span class="k">def</span> <span class="nf">predict_wright_effect</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> 
        <span class="n">intervention_data</span><span class="p">,</span> 
        <span class="n">pred_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
        <span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Predict linear effect of intervention with fitted Wright-model.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        intervention_data : numpy array</span>
<span class="sd">            Numpy array of shape (time, len(X)) that contains the do(X) values.</span>
<span class="sd">        pred_params : dict, optional</span>
<span class="sd">            Optional parameters passed on to sklearn prediction function.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        Results from prediction: an array of shape  (time, len(Y)).</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">lenX</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listX</span><span class="p">)</span>
        <span class="n">lenY</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">intervention_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="n">lenX</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;intervention_data.shape[1] must be len(X).&quot;</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_causal_path</span><span class="p">:</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;No causal path from X to Y exists.&quot;</span><span class="p">)</span>
            <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">intervention_data</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">Y</span><span class="p">)))</span>

        <span class="n">intervention_T</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">intervention_data</span><span class="o">.</span><span class="n">shape</span>


        <span class="n">predicted_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">intervention_T</span><span class="p">,</span> <span class="n">lenY</span><span class="p">))</span>
        <span class="n">pred_dict</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="k">for</span> <span class="n">iy</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">listY</span><span class="p">):</span>
            <span class="c1"># Print message</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
                <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">## Predicting target </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>
                <span class="k">if</span> <span class="n">pred_params</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                    <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">pred_params</span><span class="p">):</span>
                        <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2"> = </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">pred_params</span><span class="p">[</span><span class="n">key</span><span class="p">]))</span>
            <span class="c1"># Default value for pred_params</span>
            <span class="k">if</span> <span class="n">pred_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="n">pred_params</span> <span class="o">=</span> <span class="p">{}</span>
            <span class="c1"># Check this is a valid target</span>
            <span class="k">if</span> <span class="n">y</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">fit_results</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;y = </span><span class="si">%s</span><span class="s2"> not yet fitted&quot;</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>

            <span class="c1"># data_transform is too complicated for Wright estimator</span>
            <span class="c1"># Transform the data if needed</span>
            <span class="c1"># fitted_data_transform = self.model.fit_results[y][&#39;fitted_data_transform&#39;]</span>
            <span class="c1"># if fitted_data_transform is not None:</span>
            <span class="c1">#     intervention_data = fitted_data_transform[&#39;X&#39;].transform(X=intervention_data)</span>

            <span class="c1"># Now iterate through interventions (and potentially S)</span>
            <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">dox_vals</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">intervention_data</span><span class="p">):</span>
                <span class="c1"># Construct XZS-array</span>
                <span class="n">intervention_array</span> <span class="o">=</span> <span class="n">dox_vals</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">lenX</span><span class="p">)</span> 
                <span class="n">predictor_array</span> <span class="o">=</span> <span class="n">intervention_array</span>

                <span class="n">predicted_vals</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">fit_results</span><span class="p">[</span><span class="n">y</span><span class="p">][</span><span class="s1">&#39;model&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span>
                <span class="n">X</span><span class="o">=</span><span class="n">predictor_array</span><span class="p">,</span> <span class="o">**</span><span class="n">pred_params</span><span class="p">)</span>
                <span class="n">predicted_array</span><span class="p">[</span><span class="n">index</span><span class="p">,</span> <span class="n">iy</span><span class="p">]</span> <span class="o">=</span> <span class="n">predicted_vals</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>

                <span class="c1"># data_transform is too complicated for Wright estimator</span>
                <span class="c1"># if fitted_data_transform is not None:</span>
                <span class="c1">#     rescaled = fitted_data_transform[&#39;Y&#39;].inverse_transform(X=predicted_array[index, iy].reshape(-1, 1))</span>
                <span class="c1">#     predicted_array[index, iy] = rescaled.squeeze()</span>

        <span class="k">return</span> <span class="n">predicted_array</span></div>


<div class="viewcode-block" id="CausalEffects.fit_bootstrap_of"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.fit_bootstrap_of">[docs]</a>    <span class="k">def</span> <span class="nf">fit_bootstrap_of</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">method</span><span class="p">,</span> <span class="n">method_args</span><span class="p">,</span> 
                        <span class="n">boot_samples</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
                        <span class="n">boot_blocklength</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                        <span class="n">seed</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Runs chosen method on bootstrap samples drawn from DataFrame.</span>
<span class="sd">        </span>
<span class="sd">        Bootstraps for tau=0 are drawn from [max_lag, ..., T] and all lagged</span>
<span class="sd">        variables constructed in DataFrame.construct_array are consistently</span>
<span class="sd">        shifted with respect to this bootsrap sample to ensure that lagged</span>
<span class="sd">        relations in the bootstrap sample are preserved.</span>

<span class="sd">        This function fits the models, predict_bootstrap_of can then be used</span>
<span class="sd">        to get confidence intervals for the effect of interventions.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        method : str</span>
<span class="sd">            Chosen method among valid functions in this class.</span>
<span class="sd">        method_args : dict</span>
<span class="sd">            Arguments passed to method.</span>
<span class="sd">        boot_samples : int</span>
<span class="sd">            Number of bootstrap samples to draw.</span>
<span class="sd">        boot_blocklength : int, optional (default: 1)</span>
<span class="sd">            Block length for block-bootstrap.</span>
<span class="sd">        seed : int, optional(default = None)</span>
<span class="sd">            Seed for RandomState (default_rng)</span>
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="c1"># if dataframe.analysis_mode != &#39;single&#39;:</span>
        <span class="c1">#     raise ValueError(&quot;CausalEffects class currently only supports single &quot;</span>
        <span class="c1">#                      &quot;datasets.&quot;)</span>

        <span class="n">valid_methods</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;fit_total_effect&#39;</span><span class="p">,</span>
                         <span class="s1">&#39;fit_wright_effect&#39;</span><span class="p">,</span>
                          <span class="p">]</span>

        <span class="k">if</span> <span class="n">method</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">valid_methods</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;method must be one of </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">valid_methods</span><span class="p">))</span>

        <span class="c1"># First call the method on the original dataframe </span>
        <span class="c1"># to make available adjustment set etc</span>
        <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">method</span><span class="p">)(</span><span class="o">**</span><span class="n">method_args</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">original_model</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbosity</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">##</span><span class="se">\n</span><span class="s2">## Running Bootstrap of </span><span class="si">%s</span><span class="s2"> &quot;</span> <span class="o">%</span> <span class="n">method</span> <span class="o">+</span>
                  <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">##</span><span class="se">\n</span><span class="s2">&quot;</span> <span class="o">+</span>
                  <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">boot_samples = </span><span class="si">%s</span><span class="s2"> </span><span class="se">\n</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">boot_samples</span> <span class="o">+</span>
                  <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">boot_blocklength = </span><span class="si">%s</span><span class="s2"> </span><span class="se">\n</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">boot_blocklength</span>
                  <span class="p">)</span>

        <span class="n">method_args_bootstrap</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">method_args</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bootstrap_results</span> <span class="o">=</span> <span class="p">{}</span>

        <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">boot_samples</span><span class="p">):</span>
            <span class="c1"># # Replace dataframe in method args by bootstrapped dataframe</span>
            <span class="c1"># method_args_bootstrap[&#39;dataframe&#39;].bootstrap = boot_draw</span>
            <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
                <span class="n">random_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">default_rng</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">random_state</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">default_rng</span><span class="p">(</span><span class="n">seed</span><span class="o">*</span><span class="n">boot_samples</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span>

            <span class="n">method_args_bootstrap</span><span class="p">[</span><span class="s1">&#39;dataframe&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">bootstrap</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;boot_blocklength&#39;</span><span class="p">:</span><span class="n">boot_blocklength</span><span class="p">,</span>
                                                            <span class="s1">&#39;random_state&#39;</span><span class="p">:</span><span class="n">random_state</span><span class="p">}</span>

            <span class="c1"># Call method and save fitted model</span>
            <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">method</span><span class="p">)(</span><span class="o">**</span><span class="n">method_args_bootstrap</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">bootstrap_results</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="p">)</span>

        <span class="c1"># Reset model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_model</span>

        <span class="k">return</span> <span class="bp">self</span></div>


<div class="viewcode-block" id="CausalEffects.predict_bootstrap_of"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.predict_bootstrap_of">[docs]</a>    <span class="k">def</span> <span class="nf">predict_bootstrap_of</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">method</span><span class="p">,</span> <span class="n">method_args</span><span class="p">,</span> 
                        <span class="n">conf_lev</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
                        <span class="n">return_individual_bootstrap_results</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Predicts with the fitted bootstrap models.</span>

<span class="sd">        Must be called after fit_bootstrap_of. If the prediction method</span>
<span class="sd">        returns a tuple, only its first element (the expected values) is</span>
<span class="sd">        used; any further output is ignored.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        method : str</span>
<span class="sd">            Prediction method, one of &#39;predict_total_effect&#39; or</span>
<span class="sd">            &#39;predict_wright_effect&#39;.</span>
<span class="sd">        method_args : dict</span>
<span class="sd">            Keyword arguments passed to method; must contain</span>
<span class="sd">            &#39;intervention_data&#39; as a 2D array.</span>
<span class="sd">        conf_lev : float, optional (default: 0.9)</span>
<span class="sd">            Confidence level of the two-sided interval.</span>
<span class="sd">        return_individual_bootstrap_results : bool, optional (default: False)</span>
<span class="sd">            Whether to also return the individual bootstrap predictions.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
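<span class="sd">        confidence_intervals : numpy array</span>
<span class="sd">            Array of shape (2,) + shape of a single prediction holding the</span>
<span class="sd">            lower and upper percentile bounds. If</span>
<span class="sd">            return_individual_bootstrap_results is True, the tuple</span>
<span class="sd">            (bootstrap_predicted_array, confidence_intervals) is returned,</span>
<span class="sd">            where bootstrap_predicted_array has shape (boot_samples,) + shape</span>
<span class="sd">            of a single prediction.</span>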
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="n">valid_methods</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;predict_total_effect&#39;</span><span class="p">,</span>
                         <span class="s1">&#39;predict_wright_effect&#39;</span><span class="p">,</span>
                          <span class="p">]</span>

        <span class="k">if</span> <span class="n">method</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">valid_methods</span><span class="p">:</span>
            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;method must be one of </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="nb">str</span><span class="p">(</span><span class="n">valid_methods</span><span class="p">))</span>

        <span class="n">intervention_T</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">method_args</span><span class="p">[</span><span class="s1">&#39;intervention_data&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span>

        <span class="n">boot_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bootstrap_results</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">boot_samples</span><span class="p">):</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bootstrap_results</span><span class="p">[</span><span class="n">b</span><span class="p">]</span>
            <span class="n">boot_effect</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">method</span><span class="p">)(</span><span class="o">**</span><span class="n">method_args</span><span class="p">)</span>

            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">boot_effect</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
                <span class="n">boot_effect</span> <span class="o">=</span> <span class="n">boot_effect</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
            
            <span class="k">if</span> <span class="n">b</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">bootstrap_predicted_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">boot_samples</span><span class="p">,</span> <span class="p">)</span> <span class="o">+</span> <span class="n">boot_effect</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> 
                                            <span class="n">dtype</span><span class="o">=</span><span class="n">boot_effect</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
            <span class="n">bootstrap_predicted_array</span><span class="p">[</span><span class="n">b</span><span class="p">]</span> <span class="o">=</span> <span class="n">boot_effect</span>

        <span class="c1"># Reset model</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_model</span>

        <span class="c1"># Two-sided percentile confidence interval of the bootstrap predictions</span>
        <span class="n">c_int</span> <span class="o">=</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">conf_lev</span><span class="p">)</span><span class="o">/</span><span class="mf">2.</span><span class="p">)</span>
        <span class="n">confidence_interval</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">percentile</span><span class="p">(</span>
                <span class="n">bootstrap_predicted_array</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
                <span class="n">q</span><span class="o">=</span><span class="p">[</span><span class="mi">100</span><span class="o">*</span><span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">c_int</span><span class="p">),</span> <span class="mi">100</span><span class="o">*</span><span class="n">c_int</span><span class="p">])</span>
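        <span class="c1"># Worked example: conf_lev=0.9 gives c_int=0.95, i.e. the 5th and 95th</span>
        <span class="c1"># percentiles. confidence_interval[0] holds the lower and</span>
        <span class="c1"># confidence_interval[1] the upper bounds, each with the shape of a</span>
        <span class="c1"># single prediction.</span>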

        <span class="k">if</span> <span class="n">return_individual_bootstrap_results</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">bootstrap_predicted_array</span><span class="p">,</span> <span class="n">confidence_interval</span>

        <span class="k">return</span> <span class="n">confidence_interval</span></div>

<div class="viewcode-block" id="CausalEffects.get_dict_from_graph"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.get_dict_from_graph">[docs]</a>    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">get_dict_from_graph</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="n">parents_only</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Helper function to convert graph to dictionary of links.</span>

<span class="sd">        Parameters</span>
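<span class="sd">        ----------</span>
<span class="sd">        graph : array of shape (N, N, tau_max+1)</span>
<span class="sd">            Graph in string format, where graph[i, j, tau] holds the type of</span>
<span class="sd">            the link between (i, t-tau) and (j, t), e.g. &#39;--&gt;&#39; or &#39;o-o&#39;,</span>
<span class="sd">            and &#39;&#39; means no link.</span>

<span class="sd">        parents_only : bool, optional (default: False)</span>
<span class="sd">            Whether to only return parents (&#39;--&gt;&#39; in graph).</span>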

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
<span class="sd">        links : dict</span>
<span class="sd">            Dictionary of form {0: {(0, -1): &#39;o-o&#39;, ...}, 1: {...}, ...},</span>
<span class="sd">            where links[j][(i, -tau)] = graph[i, j, tau].</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">N</span> <span class="o">=</span> <span class="n">graph</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

        <span class="n">links</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">([(</span><span class="n">j</span><span class="p">,</span> <span class="p">{})</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">)])</span>

        <span class="k">if</span> <span class="n">parents_only</span><span class="p">:</span>
            <span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">tau</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="o">==</span><span class="s1">&#39;--&gt;&#39;</span><span class="p">)):</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">][(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">)]</span> <span class="o">=</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">j</span><span class="p">,</span><span class="n">tau</span><span class="p">]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">tau</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">graph</span><span class="o">!=</span><span class="s1">&#39;&#39;</span><span class="p">)):</span>
                <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">][(</span><span class="n">i</span><span class="p">,</span> <span class="o">-</span><span class="n">tau</span><span class="p">)]</span> <span class="o">=</span> <span class="n">graph</span><span class="p">[</span><span class="n">i</span><span class="p">,</span><span class="n">j</span><span class="p">,</span><span class="n">tau</span><span class="p">]</span>

        <span class="k">return</span> <span class="n">links</span></div>
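    <span class="c1"># Illustrative example (hypothetical graph): for N=2 and tau_max=1, if the</span>
    <span class="c1"># only non-empty entries are graph[0, 1, 1] = &#39;--&gt;&#39; and</span>
    <span class="c1"># graph[1, 1, 1] = &#39;--&gt;&#39;, get_dict_from_graph(graph) returns</span>
    <span class="c1"># {0: {}, 1: {(0, -1): &#39;--&gt;&#39;, (1, -1): &#39;--&gt;&#39;}}, with numpy integers</span>
    <span class="c1"># as key entries.</span>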

<div class="viewcode-block" id="CausalEffects.get_graph_from_dict"><a class="viewcode-back" href="../../index.html#tigramite.causal_effects.CausalEffects.get_graph_from_dict">[docs]</a>    <span class="nd">@staticmethod</span>
    <span class="k">def</span> <span class="nf">get_graph_from_dict</span><span class="p">(</span><span class="n">links</span><span class="p">,</span> <span class="n">tau_max</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;Helper function to convert dictionary of links to graph array format.</span>

<span class="sd">        Parameters</span>
<span class="sd">        ----------</span>
<span class="sd">        links : dict</span>
<span class="sd">            Dictionary of form {0:[((0, -1), coeff, func), ...], 1:[...], ...}.</span>
<span class="sd">            The format {0:[(0, -1), ...], 1:[...], ...} without coefficients</span>
<span class="sd">            is also allowed. Links with coeff equal to 0 are ignored.</span>
<span class="sd">        tau_max : int or None</span>
<span class="sd">            Maximum lag. If None, the maximum lag in links is used.</span>

<span class="sd">        Returns</span>
<span class="sd">        -------</span>
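<span class="sd">        graph : array of shape (N, N, tau_max+1)</span>
<span class="sd">            Graph in string format with graph[i, j, tau] = &#39;--&gt;&#39; for a link</span>
<span class="sd">            from (i, -tau) to j, the mirrored entry graph[j, i, 0] = &#39;&lt;--&#39;</span>
<span class="sd">            for contemporaneous links, and &#39;&#39; otherwise.</span>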
<span class="sd">        &quot;&quot;&quot;</span>

        <span class="k">def</span> <span class="nf">_get_minmax_lag</span><span class="p">(</span><span class="n">links</span><span class="p">):</span>
<span class="w">            </span><span class="sd">&quot;&quot;&quot;Helper function to retrieve tau_min and tau_max from links.</span>
<span class="sd">            &quot;&quot;&quot;</span>

            <span class="n">N</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">links</span><span class="p">)</span>

            <span class="c1"># Get minimum and maximum absolute time lags</span>
            <span class="n">min_lag</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span>
            <span class="n">max_lag</span> <span class="o">=</span> <span class="mi">0</span>
            <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N</span><span class="p">):</span>
                <span class="k">for</span> <span class="n">link_props</span> <span class="ow">in</span> <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]:</span>
                    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">link_props</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">2</span><span class="p">:</span>
                        <span class="n">var</span><span class="p">,</span> <span class="n">lag</span> <span class="o">=</span> <span class="n">link_props</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
                        <span class="n">coeff</span> <span class="o">=</span> <span class="n">link_props</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
                        <span class="c1"># func = link_props[2]</span>
                        <span class="k">if</span> <span class="n">coeff</span> <span class="o">!=</span> <span class="mf">0.</span><span class="p">:</span>
                            <span class="n">min_lag</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">min_lag</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag</span><span class="p">))</span>
                            <span class="n">max_lag</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_lag</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag</span><span class="p">))</span>
                    <span class="k">else</span><span class="p">:</span>
                        <span class="n">var</span><span class="p">,</span> <span class="n">lag</span> <span class="o">=</span> <span class="n">link_props</span>
                        <span class="n">min_lag</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">min_lag</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag</span><span class="p">))</span>
                        <span class="n">max_lag</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">max_lag</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag</span><span class="p">))</span>   

            <span class="k">return</span> <span class="n">min_lag</span><span class="p">,</span> <span class="n">max_lag</span>

        <span class="n">N</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">links</span><span class="p">)</span>

        <span class="c1"># Get minimum and maximum absolute time lags</span>
        <span class="n">min_lag</span><span class="p">,</span> <span class="n">max_lag</span> <span class="o">=</span> <span class="n">_get_minmax_lag</span><span class="p">(</span><span class="n">links</span><span class="p">)</span>

        <span class="c1"># Set maximum lag</span>
        <span class="k">if</span> <span class="n">tau_max</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">tau_max</span> <span class="o">=</span> <span class="n">max_lag</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">max_lag</span> <span class="o">&gt;</span> <span class="n">tau_max</span><span class="p">:</span>
                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;tau_max is smaller than maximum lag = </span><span class="si">%d</span><span class="s2"> &quot;</span>
                                 <span class="s2">&quot;found in links, use tau_max=None or larger &quot;</span>
                                 <span class="s2">&quot;value&quot;</span> <span class="o">%</span> <span class="n">max_lag</span><span class="p">)</span>

        <span class="n">graph</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">N</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">tau_max</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;&lt;U3&#39;</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">links</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
            <span class="k">for</span> <span class="n">link_props</span> <span class="ow">in</span> <span class="n">links</span><span class="p">[</span><span class="n">j</span><span class="p">]:</span>
                <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">link_props</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">2</span><span class="p">:</span>
                    <span class="n">var</span><span class="p">,</span> <span class="n">lag</span> <span class="o">=</span> <span class="n">link_props</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
                    <span class="n">coeff</span> <span class="o">=</span> <span class="n">link_props</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
                    <span class="k">if</span> <span class="n">coeff</span> <span class="o">!=</span> <span class="mf">0.</span><span class="p">:</span>
                        <span class="n">graph</span><span class="p">[</span><span class="n">var</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag</span><span class="p">)]</span> <span class="o">=</span> <span class="s2">&quot;--&gt;&quot;</span>
                        <span class="k">if</span> <span class="n">lag</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                            <span class="n">graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">var</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;&lt;--&quot;</span>
                <span class="k">else</span><span class="p">:</span>
                    <span class="n">var</span><span class="p">,</span> <span class="n">lag</span> <span class="o">=</span> <span class="n">link_props</span>
                    <span class="n">graph</span><span class="p">[</span><span class="n">var</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="nb">abs</span><span class="p">(</span><span class="n">lag</span><span class="p">)]</span> <span class="o">=</span> <span class="s2">&quot;--&gt;&quot;</span>
                    <span class="k">if</span> <span class="n">lag</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                        <span class="n">graph</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="n">var</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;&lt;--&quot;</span>

        <span class="k">return</span> <span class="n">graph</span></div></div>
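    <span class="c1"># Illustrative example (hypothetical links, with f = lambda x: x):</span>
    <span class="c1"># get_graph_from_dict({0: [], 1: [((0, 0), 0.5, f)]}) returns an array of</span>
    <span class="c1"># shape (2, 2, 1) with graph[0, 1, 0] = &#39;--&gt;&#39;, graph[1, 0, 0] = &#39;&lt;--&#39;</span>
    <span class="c1"># and &#39;&#39; elsewhere; with coefficient 0. instead of 0.5, all entries are &#39;&#39;.</span>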


<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
    
    <span class="c1"># Consider some toy data</span>
    <span class="kn">import</span> <span class="nn">tigramite</span>
    <span class="kn">import</span> <span class="nn">tigramite.toymodels.structural_causal_processes</span> <span class="k">as</span> <span class="nn">toys</span>
    <span class="kn">import</span> <span class="nn">tigramite.data_processing</span> <span class="k">as</span> <span class="nn">pp</span>
    <span class="kn">import</span> <span class="nn">tigramite.plotting</span> <span class="k">as</span> <span class="nn">tp</span>
    <span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span>
    <span class="kn">import</span> <span class="nn">sys</span>

    <span class="kn">import</span> <span class="nn">sklearn</span>
    <span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">LinearRegression</span><span class="p">,</span> <span class="n">LogisticRegression</span>
    <span class="kn">from</span> <span class="nn">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">StandardScaler</span>
    <span class="kn">from</span> <span class="nn">sklearn.neural_network</span> <span class="kn">import</span> <span class="n">MLPRegressor</span>
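
    <span class="c1"># Illustrative round trip (not part of the original demo; the toy links</span>
    <span class="c1"># are hypothetical): convert a parents dictionary to the string-array graph</span>
    <span class="c1"># format and back, keeping only the &#39;--&gt;&#39; entries.</span>
    toy_links = {0: [(0, -1)], 1: [(1, -1), (0, 0)]}
    toy_graph = CausalEffects.get_graph_from_dict(toy_links)
    toy_parents = CausalEffects.get_dict_from_graph(toy_graph, parents_only=True)
    print(toy_graph.shape)  <span class="c1"># (2, 2, 2): N=2 variables, tau_max=1</span>
    for node in toy_parents:
        <span class="c1"># Cast numpy integers to int for readable output</span>
        print(node, sorted((int(i), int(lag)) for (i, lag) in toy_parents[node]))
    <span class="c1"># prints 0 [(0, -1)] and 1 [(0, 0), (1, -1)]</span>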


    <span class="c1"># def lin_f(x): return x</span>
    <span class="c1"># coeff = .5</span>
 
    <span class="c1"># links_coeffs = {0: [((0, -1), 0.5, lin_f)],</span>
    <span class="c1">#          1: [((1, -1), 0.5, lin_f), ((0, -1), 0.5, lin_f)],</span>
    <span class="c1">#          2: [((2, -1), 0.5, lin_f), ((1, 0), 0.5, lin_f)]</span>
    <span class="c1">#          }</span>
    <span class="c1"># T = 1000</span>
    <span class="c1"># data, nonstat = toys.structural_causal_process(</span>
    <span class="c1">#     links_coeffs, T=T, noises=None, seed=7)</span>
    <span class="c1"># dataframe = pp.DataFrame(data)</span>

    <span class="c1"># graph = CausalEffects.get_graph_from_dict(links_coeffs)</span>

    <span class="c1"># original_graph = np.array([[[&#39;&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;--&gt;&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;--&gt;&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&#39;, &#39;&#39;]],</span>

    <span class="c1">#    [[&#39;&lt;--&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&#39;, &#39;--&gt;&#39;],</span>
    <span class="c1">#     [&#39;--&gt;&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;--&gt;&#39;, &#39;&#39;]],</span>

    <span class="c1">#    [[&#39;&lt;--&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&lt;--&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&#39;, &#39;--&gt;&#39;],</span>
    <span class="c1">#     [&#39;--&gt;&#39;, &#39;&#39;]],</span>

    <span class="c1">#    [[&#39;&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&lt;--&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&lt;--&#39;, &#39;&#39;],</span>
    <span class="c1">#     [&#39;&#39;, &#39;--&gt;&#39;]]], dtype=&#39;&lt;U3&#39;)</span>
    <span class="c1"># graph = np.copy(original_graph)</span>

    <span class="c1"># # Add T &lt;-&gt; Reco and T </span>
    <span class="c1"># graph[2,3,0] = &#39;+-&gt;&#39; ; graph[3,2,0] = &#39;&lt;-+&#39;</span>
    <span class="c1"># graph[1,3,1] = &#39;&lt;-&gt;&#39; #; graph[2,1,0] = &#39;&lt;--&#39;</span>

    <span class="c1"># added = np.zeros((4, 4, 1), dtype=&#39;&lt;U3&#39;)</span>
    <span class="c1"># added[:] = &quot;&quot;</span>
    <span class="c1"># graph = np.append(graph, added , axis=2)</span>


    <span class="c1"># X = [(1, 0)]</span>
    <span class="c1"># Y = [(3, 0)]</span>

    <span class="c1"># # Initialize class as `stationary_admg`</span>
    <span class="c1"># causal_effects = CausalEffects(graph, graph_type=&#39;stationary_admg&#39;, </span>
    <span class="c1">#                             X=X, Y=Y, S=None, </span>
    <span class="c1">#                             hidden_variables=None, </span>
    <span class="c1">#                             verbosity=0)</span>

    <span class="c1"># print(causal_effects.get_optimal_set())</span>

    <span class="c1"># tp.plot_time_series_graph(</span>
    <span class="c1">#     graph = graph,</span>
    <span class="c1">#     save_name=&#39;Example_graph_in.pdf&#39;,</span>
    <span class="c1">#     # special_nodes=special_nodes,</span>
    <span class="c1">#     # var_names=var_names,</span>
    <span class="c1">#     figsize=(6, 4),</span>
    <span class="c1">#     )</span>

    <span class="c1"># tp.plot_time_series_graph(</span>
    <span class="c1">#     graph = causal_effects.graph,</span>
    <span class="c1">#     save_name=&#39;Example_graph_out.pdf&#39;,</span>
    <span class="c1">#     # special_nodes=special_nodes,</span>
    <span class="c1">#     # var_names=var_names,</span>
    <span class="c1">#     figsize=(6, 4),</span>
    <span class="c1">#     )</span>

    <span class="c1"># causal_effects.fit_wright_effect(dataframe=dataframe, </span>
    <span class="c1">#                         # links_coeffs = links_coeffs,</span>
    <span class="c1">#                         # mediation = [(1, 0), (1, -1), (1, -2)]</span>
    <span class="c1">#                         )</span>

    <span class="c1"># intervention_data = 1.*np.ones((1, 1))</span>
    <span class="c1"># y1 = causal_effects.predict_wright_effect( </span>
    <span class="c1">#         intervention_data=intervention_data,</span>
    <span class="c1">#         )</span>

    <span class="c1"># intervention_data = 0.*np.ones((1, 1))</span>
    <span class="c1"># y2 = causal_effects.predict_wright_effect( </span>
    <span class="c1">#         intervention_data=intervention_data,</span>
    <span class="c1">#         )</span>

    <span class="c1"># beta = (y1 - y2)</span>
    <span class="c1"># print(&quot;Causal effect is %.5f&quot; %(beta))</span>

    <span class="c1"># tp.plot_time_series_graph(</span>
    <span class="c1">#     graph = causal_effects.graph,</span>
    <span class="c1">#     save_name=&#39;Example_graph.pdf&#39;,</span>
    <span class="c1">#     # special_nodes=special_nodes,</span>
    <span class="c1">#     var_names=var_names,</span>
    <span class="c1">#     figsize=(8, 4),</span>
    <span class="c1">#     )</span>

    <span class="n">T</span> <span class="o">=</span> <span class="mi">10000</span>
    <span class="k">def</span> <span class="nf">lin_f</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="n">x</span>

    <span class="n">auto_coeff</span> <span class="o">=</span> <span class="mf">0.</span>
    <span class="n">coeff</span> <span class="o">=</span> <span class="mf">2.</span>

    <span class="n">links</span> <span class="o">=</span> <span class="p">{</span>
            <span class="mi">0</span><span class="p">:</span> <span class="p">[((</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">auto_coeff</span><span class="p">,</span> <span class="n">lin_f</span><span class="p">)],</span> 
            <span class="mi">1</span><span class="p">:</span> <span class="p">[((</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">auto_coeff</span><span class="p">,</span> <span class="n">lin_f</span><span class="p">)],</span> 
            <span class="mi">2</span><span class="p">:</span> <span class="p">[((</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">auto_coeff</span><span class="p">,</span> <span class="n">lin_f</span><span class="p">),</span> <span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">coeff</span><span class="p">,</span> <span class="n">lin_f</span><span class="p">)],</span>
            <span class="mi">3</span><span class="p">:</span> <span class="p">[((</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">auto_coeff</span><span class="p">,</span> <span class="n">lin_f</span><span class="p">)],</span> 
            <span class="p">}</span>
    <span class="n">data</span><span class="p">,</span> <span class="n">nonstat</span> <span class="o">=</span> <span class="n">toys</span><span class="o">.</span><span class="n">structural_causal_process</span><span class="p">(</span><span class="n">links</span><span class="p">,</span> <span class="n">T</span><span class="o">=</span><span class="n">T</span><span class="p">,</span> 
                                <span class="n">noises</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>


    <span class="c1"># # Create some missing values</span>
    <span class="c1"># data[-10:,:] = 999.</span>
    <span class="c1"># var_names = range(2)</span>

    <span class="n">dataframe</span> <span class="o">=</span> <span class="n">pp</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">data</span><span class="p">,</span>
                    <span class="n">vector_vars</span><span class="o">=</span><span class="p">{</span><span class="mi">0</span><span class="p">:[(</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span><span class="mi">0</span><span class="p">)],</span> 
                                 <span class="mi">1</span><span class="p">:[(</span><span class="mi">2</span><span class="p">,</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span><span class="mi">0</span><span class="p">)]}</span>
                    <span class="p">)</span>

    <span class="c1"># # Construct expert knowledge graph from links here </span>
    <span class="n">aux_links</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)],</span>
                 <span class="mi">1</span><span class="p">:</span> <span class="p">[(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)],</span>
              <span class="p">}</span>
    <span class="c1"># # Use staticmethod to get graph</span>
    <span class="n">graph</span> <span class="o">=</span> <span class="n">CausalEffects</span><span class="o">.</span><span class="n">get_graph_from_dict</span><span class="p">(</span><span class="n">aux_links</span><span class="p">,</span> <span class="n">tau_max</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="c1"># graph = np.array([[&#39;&#39;, &#39;--&gt;&#39;],</span>
    <span class="c1">#                   [&#39;&lt;--&#39;, &#39;&#39;]], dtype=&#39;&lt;U3&#39;)</span>
    
    <span class="c1"># # We are interested in lagged total effect of X on Y</span>
    <span class="n">X</span> <span class="o">=</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)]</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="p">[(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)]</span>

    <span class="c1"># # Initialize class as `stationary_dag`</span>
    <span class="n">causal_effects</span> <span class="o">=</span> <span class="n">CausalEffects</span><span class="p">(</span><span class="n">graph</span><span class="p">,</span> <span class="n">graph_type</span><span class="o">=</span><span class="s1">&#39;stationary_dag&#39;</span><span class="p">,</span> 
                                <span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="o">=</span><span class="n">Y</span><span class="p">,</span> <span class="n">S</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> 
                                <span class="n">hidden_variables</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> 
                                <span class="n">verbosity</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># print(data)</span>
    <span class="c1"># # Optimal adjustment set (is used by default)</span>
    <span class="c1"># # print(causal_effects.get_optimal_set())</span>

    <span class="c1"># # # Fit causal effect model from observational data</span>
    <span class="n">causal_effects</span><span class="o">.</span><span class="n">fit_total_effect</span><span class="p">(</span>
        <span class="n">dataframe</span><span class="o">=</span><span class="n">dataframe</span><span class="p">,</span> 
        <span class="c1"># mask_type=&#39;y&#39;,</span>
        <span class="n">estimator</span><span class="o">=</span><span class="n">LinearRegression</span><span class="p">(),</span>
        <span class="p">)</span>

    <span class="c1"># # Fit causal effect model from observational data</span>
    <span class="c1"># causal_effects.fit_bootstrap_of(</span>
    <span class="c1">#     method=&#39;fit_total_effect&#39;,</span>
    <span class="c1">#     method_args={&#39;dataframe&#39;:dataframe,  </span>
    <span class="c1">#     # mask_type=&#39;y&#39;,</span>
    <span class="c1">#     &#39;estimator&#39;:LinearRegression()</span>
    <span class="c1">#     },</span>
    <span class="c1">#     boot_samples=3,</span>
    <span class="c1">#     boot_blocklength=1,</span>
    <span class="c1">#     seed=5</span>
    <span class="c1">#     )</span>


    <span class="c1"># Predict effect of interventions do(X=0.), ..., do(X=1.) in one go</span>
    <span class="n">lenX</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># len(X) * len(dataframe.vector_vars[X[0][0]]) = 2 * 2</span>
    <span class="n">dox_vals</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
    <span class="n">intervention_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">dox_vals</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dox_vals</span><span class="p">),</span> <span class="mi">1</span><span class="p">),</span> <span class="n">lenX</span><span class="p">)</span>

    <span class="c1"># Overrides the grid above: a single intervention setting only the first component of X to 1</span>
    <span class="n">intervention_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">]])</span>

    <span class="nb">print</span><span class="p">(</span><span class="n">intervention_data</span><span class="p">)</span>

    <span class="n">pred_Y</span> <span class="o">=</span> <span class="n">causal_effects</span><span class="o">.</span><span class="n">predict_total_effect</span><span class="p">(</span> 
            <span class="n">intervention_data</span><span class="o">=</span><span class="n">intervention_data</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">pred_Y</span><span class="p">,</span> <span class="n">pred_Y</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="c1"># # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go</span>
    <span class="c1"># # dox_vals = np.array([1.]) #np.linspace(0., 1., 1)</span>
    <span class="c1"># intervention_data = np.tile(dox_vals.reshape(len(dox_vals), 1), len(X))</span>
    <span class="c1"># conf = causal_effects.predict_bootstrap_of(</span>
    <span class="c1">#     method=&#39;predict_total_effect&#39;,</span>
    <span class="c1">#     method_args={&#39;intervention_data&#39;:intervention_data})</span>
    <span class="c1"># print(conf, conf.shape)</span>



    <span class="c1"># # # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go</span>
    <span class="c1"># # dox_vals = np.array([1.]) #np.linspace(0., 1., 1)</span>
    <span class="c1"># # intervention_data = dox_vals.reshape(len(dox_vals), len(X))</span>
    <span class="c1"># # pred_Y = causal_effects.predict_total_effect( </span>
    <span class="c1"># #         intervention_data=intervention_data)</span>
    <span class="c1"># # print(pred_Y)</span>



    <span class="c1"># # Fit causal effect model from observational data</span>
    <span class="c1"># causal_effects.fit_wright_effect(</span>
    <span class="c1">#     dataframe=dataframe, </span>
    <span class="c1">#     # mask_type=&#39;y&#39;,</span>
    <span class="c1">#     # estimator=LinearRegression(),</span>
    <span class="c1">#     # data_transform=StandardScaler(),</span>
    <span class="c1">#     )</span>

    <span class="c1"># # # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go</span>
    <span class="c1"># dox_vals = np.linspace(0., 1., 5)</span>
    <span class="c1"># intervention_data = dox_vals.reshape(len(dox_vals), len(X))</span>
    <span class="c1"># pred_Y = causal_effects.predict_wright_effect( </span>
    <span class="c1">#         intervention_data=intervention_data)</span>
    <span class="c1"># print(pred_Y)</span>
</pre></div>
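<p>The value <code>lenX = 4</code> and the <code>np.tile</code> call in the listing can be checked with plain NumPy. The sketch below assumes that each entry of <code>X</code> expands into the scalar components of its vector variable at that lag, in the order listed in <code>vector_vars</code>. This column order is an illustration and is not taken from the Tigramite source. The names <code>flat_X</code>, <code>grid</code> and <code>single</code> are hypothetical.</p>

```python
import numpy as np

# Values copied from the listing above.
vector_vars = {0: [(0, 0), (1, 0)], 1: [(2, 0), (3, 0)]}
X = [(0, 0), (0, -1)]

# Assumed flattening: each (vector variable, lag) in X expands into its
# scalar components at that lag, in the order they are listed.
flat_X = [(comp, lag) for var, lag in X for comp, _ in vector_vars[var]]
lenX = len(flat_X)  # 2 lags x 2 components = 4

# Grid of interventions: row k applies do(X = dox_vals[k]) to every
# flattened component, giving an array of shape (len(dox_vals), lenX).
dox_vals = np.linspace(0., 1., 3)  # [0.0, 0.5, 1.0]
grid = np.tile(dox_vals.reshape(len(dox_vals), 1), lenX)

# Single intervention used in the listing: first component set to 1,
# all others held at 0.
single = np.zeros((1, lenX))
single[0, 0] = 1.0

print(flat_X)
print(grid)
print(single)
```

<p>Each row of <code>intervention_data</code> is one intervention, so a single call to <code>predict_total_effect</code> can evaluate several <code>do(X=...)</code> settings at once.</p>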

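<p>In the toy process, <code>auto_coeff = 0</code>. Every variable is therefore white noise except variable 2, which receives variable 0 contemporaneously with <code>coeff = 2</code>. Under the intervention <code>[1., 0., 0., 0.]</code>, a linear estimator should predict a value near 2 for the lag-0 value of variable 2, which is the first component of <code>Y</code>. The sketch below is a rough sanity check and does not reproduce what <code>fit_total_effect</code> does internally. It re-simulates an equivalent process with NumPy instead of <code>toys.structural_causal_process</code>, assumes standard-normal noise and the column order <code>x0_t, x1_t, x0_{t-1}, x1_{t-1}</code>, and uses ordinary least squares (the same model class as <code>LinearRegression</code>).</p>

```python
import numpy as np

# Hypothetical re-simulation of the toy SCM above with plain NumPy (not
# toys.structural_causal_process): auto_coeff = 0, so every variable is
# white noise except x2, which receives x0 contemporaneously with coeff = 2.
rng = np.random.default_rng(7)
T = 10000
coeff = 2.0
x0, x1, x3 = rng.standard_normal((3, T))  # x3 is independent noise, unused below
x2 = coeff * x0 + rng.standard_normal(T)

# Assumed design matrix for X = [(0, 0), (0, -1)] with vector variable
# 0 = (x0, x1): columns x0_t, x1_t, x0_{t-1}, x1_{t-1}, plus an intercept.
Xmat = np.column_stack([x0[1:], x1[1:], x0[:-1], x1[:-1], np.ones(T - 1)])
y2 = x2[1:]  # first component of vector variable 1 at lag 0

# Ordinary least squares fit (same model class as sklearn's LinearRegression).
beta, *_ = np.linalg.lstsq(Xmat, y2, rcond=None)

# Predicted x2_t under do(X = [1, 0, 0, 0]); the trailing 1 is the intercept.
do_x = np.array([1.0, 0.0, 0.0, 0.0, 1.0])
pred_x2 = do_x @ beta
print(np.round(beta[:4], 2), round(float(pred_x2), 2))
```

<p>Under these assumptions, the coefficients on <code>x1_t</code> and on both lag-1 columns are close to zero. The other components of <code>Y</code> also have a true effect of zero under this intervention.</p>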
          </div>
          
        </div>
      </div>
      <div class="sphinxsidebar" role="navigation" aria-label="main navigation">
        <div class="sphinxsidebarwrapper">
<h1 class="logo"><a href="../../index.html">Tigramite</a></h1>
        </div>
      </div>
      <div class="clearer"></div>
    </div>
    <div class="footer">
      &copy;2023, Jakob Runge.
      
      |
      Powered by <a href="http://sphinx-doc.org/">Sphinx 5.0.2</a>
      &amp; <a href="https://github.com/bitprophet/alabaster">Alabaster 0.7.12</a>
      
    </div>

    

    
  </body>
</html>